Keras 2 API 文档 / 模型 API / 保存与序列化 / 用于推理的模型导出

用于推理的模型导出

[来源]

ExportArchive

tf_keras.export.ExportArchive()

ExportArchive 用于写入 SavedModel 工件(例如,用于推理)。

如果您有一个 TF-Keras 模型或层,想要将其导出为 SavedModel 以供服务(例如,通过 TensorFlow-Serving),您可以使用 ExportArchive 配置您需要提供的不同服务端点及其签名。只需实例化一个 ExportArchive,使用 track() 注册要使用的层或模型,然后使用 add_endpoint() 方法注册新的服务端点。完成后,使用 write_out() 方法保存工件。

生成的工件是 SavedModel,可以通过 tf.saved_model.load 重新加载。

示例

以下是导出模型以进行推理的方法。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)

以下是导出一个模型,其中一个端点用于推理,另一个端点用于训练模式的前向传递(例如,带有 dropout)的方法。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="call_inference",
    fn=lambda x: model.call(x, training=False),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
    name="call_training",
    fn=lambda x: model.call(x, training=True),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

关于资源跟踪的说明

ExportArchive 能够自动跟踪其端点使用的所有 tf.Variables,因此大多数情况下,调用 .track(model) 并不是严格必需的。但是,如果您的模型使用查找层,例如 IntegerLookupStringLookupTextVectorization,则需要通过 .track(model) 显式跟踪它。

如果您需要能够访问恢复的归档文件上的 variablestrainable_variablesnon_trainable_variables 属性,则也需要显式跟踪。