ExportArchive
类tf_keras.export.ExportArchive()
ExportArchive
用于写入 SavedModel 工件(例如,用于推理)。
如果您有一个 TF-Keras 模型或层,想要将其导出为 SavedModel 以供服务(例如,通过 TensorFlow-Serving),您可以使用 ExportArchive
配置您需要提供的不同服务端点及其签名。只需实例化一个 ExportArchive
,使用 track()
注册要使用的层或模型,然后使用 add_endpoint()
方法注册新的服务端点。完成后,使用 write_out()
方法保存工件。
生成的工件是 SavedModel,可以通过 tf.saved_model.load
重新加载。
示例
以下是导出模型以进行推理的方法。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)
以下是导出一个模型,其中一个端点用于推理,另一个端点用于训练模式的前向传递(例如,带有 dropout)的方法。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model.call(x, training=False),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model.call(x, training=True),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
关于资源跟踪的说明
ExportArchive
能够自动跟踪其端点使用的所有 tf.Variables
,因此大多数情况下,调用 .track(model)
并不是严格必需的。但是,如果您的模型使用查找层,例如 IntegerLookup
、StringLookup
或 TextVectorization
,则需要通过 .track(model)
显式跟踪它。
如果您需要能够访问恢复的归档文件上的 variables
、trainable_variables
或 non_trainable_variables
属性,则也需要显式跟踪。