8wDlpd.png
8wDFp9.png
8wDEOx.png
8wDMfH.png
8wDKte.png

Tensorflow 2.17 - 保存自定义模型不起作用

Brze 1月前

11 0

(tensorflow 2.17,Windows 10)我无法使用子类化方法保存和恢复自定义模型。请在下面找到代码来重现该问题:import numpy as npimport tensorflow as tfx =...

(TensorFlow 2.17,Windows 10)

我无法使用子类化方法保存和恢复自定义模型。

请参阅以下代码来重现该问题:

import numpy as np
import tensorflow as tf

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))

@tf.keras.utils.register_keras_serializable()
class MModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(10)
        self.dense2 = tf.keras.layers.Dense(1)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MModel()

model.compile( optimizer="adam",  loss="mse",  metrics=["mae"])
model.fit( x, y, epochs=5)

model.save("save1.keras")
model_ = tf.keras.models.load_model("save1.keras")

print()
print(f"{model.evaluate(x, y,verbose=0)  = }")
print(f"{model_.evaluate(x, y,verbose=0) = }")

for m, m_ in zip(model.weights, model_.weights):
    np.testing.assert_allclose(m.numpy(), m_.numpy())

请注意, init () 没有参数,因此不需要覆盖 get_config()。

错误信息:

TypeError: <class '__main__.MModel'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'module': None, 'class_name': 'MModel', 'config': {'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}}, 'registered_name': 'Custom>MModel', 'build_config': {'input_shape': [None, 32]}, 'compile_config': {'optimizer': {'module': 'keras.optimizers', 'class_name': 'Adam', 'config': {'name': 'adam', 'learning_rate': 0.0010000000474974513, 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'loss_scale_factor': None, 'gradient_accumulation_steps': None, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}, 'registered_name': None}, 'loss': 'mse', 'loss_weights': None, 'metrics': ['mae'], 'weighted_metrics': None, 'run_eagerly': False, 'steps_per_execution': 1, 'jit_compile': False}}.

Exception encountered: Unable to revive model from config. When overriding the `get_config()` method, make sure that the returned config contains all items used as arguments in the  constructor to <class '__main__.MModel'>, which is the default behavior. You can override this default behavior by defining a `from_config(cls, config)` class method to specify how to create an instance of MModel from its config.

Received config={'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}}

Error encountered during deserialization: MModel.__init__() got an unexpected keyword argument 'trainable'
帖子版权声明 1、本帖标题:Tensorflow 2.17 - 保存自定义模型不起作用
    本站网址:http://xjnalaquan.com/
2、本网站的资源部分来源于网络,如有侵权,请联系站长进行删除处理。
3、会员发帖仅代表会员个人观点,并不代表本站赞同其观点和对其真实性负责。
4、本站一律禁止以任何方式发布或转载任何违法的相关信息,访客发现请向站长举报
5、站长邮箱:yeweds@126.com 除非注明,本帖由Brze在本站《tensorflow》版块原创发布, 转载请注明出处!
最新回复 (0)
返回
作者最近主题: