(tensorflow 2.17,Windows 10)我无法使用子类化方法保存和恢复自定义模型。请在下面找到代码来重现该问题:import numpy as npimport tensorflow as tfx =...
(TensorFlow 2.17,Windows 10)
我无法使用子类化方法保存和恢复自定义模型。
请参阅以下代码来重现该问题:
import numpy as np
import tensorflow as tf
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
@tf.keras.utils.register_keras_serializable()
class MModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(10)
self.dense2 = tf.keras.layers.Dense(1)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MModel()
model.compile( optimizer="adam", loss="mse", metrics=["mae"])
model.fit( x, y, epochs=5)
model.save("save1.keras")
model_ = tf.keras.models.load_model("save1.keras")
print()
print(f"{model.evaluate(x, y,verbose=0) = }")
print(f"{model_.evaluate(x, y,verbose=0) = }")
for m, m_ in zip(model.weights, model_.weights):
np.testing.assert_allclose(m.numpy(), m_.numpy())
请注意, init () 没有参数,因此不需要覆盖 get_config()。
错误信息:
TypeError: <class '__main__.MModel'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.
config={'module': None, 'class_name': 'MModel', 'config': {'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}}, 'registered_name': 'Custom>MModel', 'build_config': {'input_shape': [None, 32]}, 'compile_config': {'optimizer': {'module': 'keras.optimizers', 'class_name': 'Adam', 'config': {'name': 'adam', 'learning_rate': 0.0010000000474974513, 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'loss_scale_factor': None, 'gradient_accumulation_steps': None, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}, 'registered_name': None}, 'loss': 'mse', 'loss_weights': None, 'metrics': ['mae'], 'weighted_metrics': None, 'run_eagerly': False, 'steps_per_execution': 1, 'jit_compile': False}}.
Exception encountered: Unable to revive model from config. When overriding the `get_config()` method, make sure that the returned config contains all items used as arguments in the constructor to <class '__main__.MModel'>, which is the default behavior. You can override this default behavior by defining a `from_config(cls, config)` class method to specify how to create an instance of MModel from its config.
Received config={'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}}
Error encountered during deserialization: MModel.__init__() got an unexpected keyword argument 'trainable'