序列化工具

[源代码]

serialize_keras_object 函数

tf_keras.utils.serialize_keras_object(obj)

通过序列化 TF-Keras 对象来检索配置字典。

serialize_keras_object() 将 TF-Keras 对象序列化为一个表示该对象的 Python 字典,并且是 deserialize_keras_object() 的逆函数。有关配置格式的更多信息,请参阅 deserialize_keras_object()

参数

  • obj:要序列化的 TF-Keras 对象。

返回值

表示该对象的 Python 字典。该 Python 字典可以通过 deserialize_keras_object() 反序列化。


[源代码]

deserialize_keras_object 函数

tf_keras.utils.deserialize_keras_object(
    config, custom_objects=None, safe_mode=True, **kwargs
)

通过反序列化配置字典来检索对象。

配置字典是一个 Python 字典,由一组键值对组成,并表示一个 TF-Keras 对象,例如 OptimizerLayerMetrics 等。保存和加载库使用以下键来记录 TF-Keras 对象的信息:

  • class_name:字符串。这是类的名称,与源代码中完全定义的名称一致,例如 "LossesContainer"。
  • config:字典。库定义或用户定义的键值对,用于存储对象的配置,如通过 object.get_config() 获取。
  • module:字符串。Python 模块的路径,例如 "keras.engine.compile_utils"。内置的 TF-Keras 类应具有前缀 keras
  • registered_name:字符串。该类通过 keras.saving.register_keras_serializable(package, name) API 注册所使用的键。该键的格式为 '{package}>{name}',其中 packagename 是传递给 register_keras_serializable() 的参数。如果未提供 name,则使用类名。如果 registered_name 成功解析为一个类(已注册),则不会使用字典中的 class_nameconfig 值。registered_name 仅用于非内置类。

例如,以下字典表示内置的 Adam 优化器及其相关配置:

dict_structure = {
    "class_name": "Adam",
    "config": {
        "amsgrad": false,
        "beta_1": 0.8999999761581421,
        "beta_2": 0.9990000128746033,
        "decay": 0.0,
        "epsilon": 1e-07,
        "learning_rate": 0.0010000000474974513,
        "name": "Adam"
    },
    "module": "keras.optimizers",
    "registered_name": None
}
# Returns an `Adam` instance identical to the original one.
deserialize_keras_object(dict_structure)

如果该类没有导出的 TF-Keras 命名空间,则库会通过其 moduleclass_name 来跟踪它。例如:

dict_structure = {
  "class_name": "LossesContainer",
  "config": {
      "losses": [...],
      "total_loss_mean": {...},
  },
  "module": "keras.engine.compile_utils",
  "registered_name": "LossesContainer"
}

# Returns a `LossesContainer` instance identical to the original one.
deserialize_keras_object(dict_structure)

以下字典表示用户自定义的 MeanSquaredError 损失函数:

@keras.saving.register_keras_serializable(package='my_package')
class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
  ...

dict_structure = {
    "class_name": "ModifiedMeanSquaredError",
    "config": {
        "fn": "mean_squared_error",
        "name": "mean_squared_error",
        "reduction": "auto"
    },
    "registered_name": "my_package>ModifiedMeanSquaredError"
}
# Returns the `ModifiedMeanSquaredError` object
deserialize_keras_object(dict_structure)

参数

  • config:描述对象的 Python 字典。
  • custom_objects:Python 字典,包含自定义对象名称与相应类或函数之间的映射。
  • safe_mode:布尔值,指示是否禁止不安全的 lambda 反序列化。当 safe_mode=False 时,加载对象可能会触发任意代码执行。此参数仅适用于 TF-Keras v3 模型格式。默认为 True

返回值

config 字典描述的对象。


[源代码]

CustomObjectScope

tf_keras.saving.custom_object_scope(*args)

将自定义类/函数公开给 TF-Keras 反序列化内部机制。

with custom_object_scope(objects_dict) 的作用域下,TF-Keras 方法(如 tf.keras.models.load_modeltf.keras.models.model_from_config)将能够反序列化保存的配置引用的任何自定义对象(例如自定义层或指标)。

示例

考虑一个自定义正则化器 my_regularizer

layer = Dense(3, kernel_regularizer=my_regularizer)
# Config contains a reference to `my_regularizer`
config = layer.get_config()
...
# Later:
with custom_object_scope({'my_regularizer': my_regularizer}):
  layer = Dense.from_config(config)

参数

  • *args:字典或 {name: object} 对的字典。

[源代码]

get_custom_objects 函数

tf_keras.saving.get_custom_objects()

检索自定义对象的全局字典的实时引用。

使用 custom_object_scope 设置的自定义对象不会添加到自定义对象的全局字典中,也不会出现在返回的字典中。

示例

get_custom_objects().clear()
get_custom_objects()['MyObject'] = MyObject

返回值

将注册的类名映射到类的全局字典。


[源代码]

register_keras_serializable 函数

tf_keras.saving.register_keras_serializable(package="Custom", name=None)

向 TF-Keras 序列化框架注册一个对象。

此装饰器将装饰的类或函数注入到 TF-Keras 自定义对象字典中,以便可以在无需用户提供的自定义对象字典中添加条目的情况下对其进行序列化和反序列化。它还会注入一个 TF-Keras 将调用的函数,以获取对象的序列化字符串键。

请注意,要进行序列化和反序列化,类必须实现 get_config() 方法。函数没有此要求。

该对象将以 'package>name' 的键注册,其中 name 如果未传递,则默认为对象名称。

示例

# Note that `'my_package'` is used as the `package` argument here, and since
# the `name` argument is not provided, `'MyDense'` is used as the `name`.
@keras.saving.register_keras_serializable('my_package')
class MyDense(keras.layers.Dense):
  pass

assert keras.saving.get_registered_object('my_package>MyDense') == MyDense
assert keras.saving.get_registered_name(MyDense) == 'my_package>MyDense'

参数

  • package:此类所属的包。这用于 key(即 "package>name")来标识该类。请注意,这是传递给装饰器的第一个参数。
  • name:在此包中序列化此类的名称。如果未提供或为 None,则将使用该类的名称(请注意,当装饰器仅使用一个参数时,该参数将成为 package,即为这种情况)。

返回值

一个装饰器,使用传递的名称注册装饰的类。