I'm training a model with ImageDataGenerator using TensorFlow and Keras, but I'm getting warning messages and training is very slow. Here is the relevant code and the problem I'm facing:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# preprocess_to_black_lines, train_df, val_df and test_df are defined elsewhere (not shown)

# Training generator: rescaling, random augmentation, and custom preprocessing
train_generator = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.10,
    height_shift_range=0.2,
    preprocessing_function=preprocess_to_black_lines
)
# Evaluation generator: same rescaling and preprocessing, but no random augmentation
test_generator = ImageDataGenerator(
    rescale=1./255,
    preprocessing_function=preprocess_to_black_lines
)

# subset='training' was dropped: without validation_split it selected every row anyway
train_images = train_generator.flow_from_dataframe(
    dataframe=train_df,
    x_col='address',
    y_col='labels',
    target_size=(64, 64),
    batch_size=200,
    color_mode='grayscale',
    class_mode='categorical',
    seed=42,
    shuffle=True
)
# Validation images come from the non-augmenting generator
val_images = test_generator.flow_from_dataframe(
    dataframe=val_df,
    x_col='address',
    y_col='labels',
    target_size=(64, 64),
    batch_size=200,
    color_mode='grayscale',
    class_mode='categorical',
    seed=42,
    shuffle=True
)
test_images = test_generator.flow_from_dataframe(
    dataframe=test_df,
    x_col='address',
    y_col='labels',
    target_size=(64, 64),
    batch_size=32,
    color_mode='grayscale',
    class_mode='categorical',
    seed=42,
    shuffle=False
)
# Model definition
inputs = tf.keras.Input(shape=(64, 64, 1))
x = tf.keras.layers.Conv2D(filters=6, kernel_size=(3, 3), activation='relu')(inputs)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(100, activation='relu')(x)
outputs = tf.keras.layers.Dense(8, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
model.summary()

history = model.fit(
    train_images,
    validation_data=val_images,
    epochs=100,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    ]
)
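The script calls preprocess_to_black_lines, whose definition isn't shown. ImageDataGenerator calls a preprocessing_function once per image, in Python, after the image has been resized and augmented; it receives a single NumPy array of shape (height, width, channels), here (64, 64, 1), and must return an array of the same shape. As far as I can tell from Keras's implementation, the function runs before rescale is applied, so it sees pixel values in the 0 to 255 range. Because it runs for every image in every epoch, a version that loops over pixels in Python can easily dominate the step time. The function below is purely hypothetical (thresholding at 128 is an assumed meaning of "black lines"); it only illustrates a vectorized implementation that satisfies this contract.

```python
import numpy as np


def preprocess_to_black_lines(image, threshold=128.0):
    """Hypothetical stand-in for the real function used in the script.

    Receives one image as a float array of shape (height, width, channels),
    assumed to still be in the 0..255 range (rescale is applied afterwards).
    """
    # Vectorized NumPy: pixels darker than the threshold become black (0),
    # everything else becomes white (255). No Python loop over pixels.
    return np.where(image < threshold, 0.0, 255.0).astype(image.dtype)


# Tiny 2x2 grayscale example with a single channel
sample = np.array([[[10.0], [200.0]], [[127.0], [128.0]]], dtype=np.float32)
print(preprocess_to_black_lines(sample)[..., 0])
```

The output keeps the input's shape and dtype, which is what ImageDataGenerator expects back from a preprocessing_function.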
When I run the code, I get the following warning:
Epoch 1/100
/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
self._warn_if_super_not_called()
59/60 5s 6s/step - accuracy: 0.1277 - loss: 3.6410 /usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
self._warn_if_super_not_called()
60/60 553s 8s/step - accuracy: 0.1279 - loss: 3.6051 - val_accuracy: 0.1367 - val_loss: 2.0608
Epoch 2/100
59/60 5s 6s/step - accuracy: 0.1709 - loss: 2.0589
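On top of that, training is very slow: the first epoch took 553 s for 60 steps, and the progress bar reports 6 to 8 s per step.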
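For scale, the model defined above is very small. Counting its parameters by hand (the same total that model.summary() should report) gives 13,496 weights, which makes it plausible that most of the time per step is spent producing batches rather than running the model. That is an inference from the architecture, not a measurement. The count uses the standard formulas: a Conv2D layer has kernel_height × kernel_width × input_channels × filters weights plus one bias per filter, and a Dense layer has inputs × units weights plus one bias per unit.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel per filter over all input channels, plus one bias per filter
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_units, out_units):
    # Full weight matrix plus one bias per output unit
    return in_units * out_units + out_units


layers = [
    ("conv2d_1", conv2d_params(3, 3, 1, 6)),    # grayscale input: 1 channel
    ("conv2d_2", conv2d_params(3, 3, 6, 16)),   # 6 feature maps in
    ("dense_1", dense_params(16, 64)),          # GlobalAveragePooling2D leaves 16 features
    ("dense_2", dense_params(64, 64)),
    ("dense_3", dense_params(64, 100)),
    ("output", dense_params(100, 8)),           # 8 classes
]
total = sum(count for _, count in layers)
for name, count in layers:
    print(f"{name}: {count}")
print("total:", total)
```

Pooling layers have no parameters, so they are omitted from the list.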
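Reviewing the warning: I read the warning message and tried adjusting the ImageDataGenerator setup. Specifically, I made sure that workers, use_multiprocessing, and max_queue_size are not passed to fit().

Optimizing the data pipeline: I verified that the data pipelines for training, validation, and testing are configured correctly and match the intended parameters.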
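On the pipeline side, TensorFlow's API documentation describes ImageDataGenerator as not recommended for new code and suggests tf.data with Keras preprocessing layers instead. A tf.data pipeline decodes and resizes images in parallel, can cache decoded images after the first epoch, and prefetches batches while the model trains. The sketch below is a rough equivalent of the training generator, assuming the 'address' column holds image file paths and that the string 'labels' have already been mapped to integer class ids (make_dataset, load_image and that mapping are hypothetical). RandomRotation(20 / 360) and RandomTranslation(0.2, 0.10) approximate rotation_range=20, height_shift_range=0.2 and width_shift_range=0.10. preprocess_to_black_lines is left out: it would have to be rewritten with TensorFlow ops (or wrapped with tf.numpy_function) to run inside the pipeline.

```python
import tensorflow as tf

IMG_SIZE = (64, 64)
AUTOTUNE = tf.data.AUTOTUNE

# Random augmentation roughly matching rotation_range=20 (degrees),
# height_shift_range=0.2 and width_shift_range=0.10
rotate = tf.keras.layers.RandomRotation(20 / 360, fill_mode="nearest")
shift = tf.keras.layers.RandomTranslation(0.2, 0.10, fill_mode="nearest")


def load_image(path, label, num_classes):
    """Read one file, decode it as grayscale, resize, and rescale to [0, 1]."""
    img = tf.io.read_file(path)
    img = tf.io.decode_image(img, channels=1, expand_animations=False)
    img = tf.image.resize(img, IMG_SIZE)  # float32, still in 0..255
    img = img / 255.0                     # same effect as rescale=1./255
    return img, tf.one_hot(label, num_classes)


def make_dataset(paths, labels, num_classes, batch_size, training):
    """Hypothetical helper: paths are file paths, labels are integer class ids."""
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    # Decode in parallel, then keep the decoded images in memory after the first pass
    ds = ds.map(lambda p, l: load_image(p, l, num_classes),
                num_parallel_calls=AUTOTUNE).cache()
    if training:
        # Shuffle after the cache so the order changes every epoch
        ds = ds.shuffle(len(paths), seed=42)
    ds = ds.batch(batch_size)
    if training:
        # Augment whole batches on the fly, after the cache
        ds = ds.map(lambda x, y: (shift(rotate(x, training=True), training=True), y),
                    num_parallel_calls=AUTOTUNE)
    # Prepare upcoming batches while the model trains on the current one
    return ds.prefetch(AUTOTUNE)
```

A dataset built this way, for example make_dataset(train_df['address'].tolist(), train_ids, num_classes=8, batch_size=200, training=True) with a hypothetical list train_ids of integer labels, can be passed to model.fit() in place of train_images; validation and test sets would use training=False so they are neither shuffled nor augmented.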
Questions:
What does this warning mean, and how can I fix it? And how can I speed up training?
Any help would be greatly appreciated!