export
方法Model.export(
filepath, format="tf_saved_model", verbose=None, input_signature=None, **kwargs
)
将模型导出为用于推断的工件。
参数
str
或 pathlib.Path
对象。保存工件的路径。str
。导出格式。支持的值:“"tf_saved_model"
”和“"onnx"
”。默认为“"tf_saved_model"
”。bool
。导出期间是否打印消息。默认为 None
,这表示使用不同后端和格式设置的默认值。keras.InputSpec
、tf.TensorSpec
、backend.KerasTensor
或后端张量的结构。如果未提供,将自动计算。默认为 None
。format="tf_saved_model"
专用: - is_static
:可选 bool
。指示 fn
是否为静态。如果 fn
涉及状态更新(例如,RNG 种子和计数器),则设置为 False
。 - jax2tf_kwargs
:可选 dict
。jax2tf.convert
的参数。请参阅 jax2tf.convert
的文档。如果未提供 native_serialization
和 polymorphic_shapes
,它们将自动计算。注意:此功能目前仅支持 TensorFlow、JAX 和 Torch 后端。
注意:请注意,当使用 format="onnx"
、verbose=True
和 Torch 后端时,导出的工件可能包含来自本地文件系统的信息。
示例
以下是如何导出用于推断的 TensorFlow SavedModel。
# Export the model as a TensorFlow SavedModel artifact
model.export("path/to/location", format="tf_saved_model")
# Load the artifact in a different process/environment
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)
以下是如何导出用于推断的 ONNX。
# Export the model as a ONNX artifact
model.export("path/to/location", format="onnx")
# Load the artifact in a different process/environment
ort_session = onnxruntime.InferenceSession("path/to/location")
ort_inputs = {
k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
}
predictions = ort_session.run(None, ort_inputs)
ExportArchive
类keras.export.ExportArchive()
ExportArchive 用于写入 SavedModel 工件(例如用于推断)。
如果您有一个 Keras 模型或层,想要将其导出为 SavedModel 以供服务(例如通过 TensorFlow-Serving),您可以使用 ExportArchive
来配置您需要提供的不同服务终结点及其签名。只需实例化一个 ExportArchive
,使用 track()
注册要使用的层或模型,然后使用 add_endpoint()
方法注册新的服务终结点。完成后,使用 write_out()
方法保存工件。
生成的工件是 SavedModel,可以通过 tf.saved_model.load
重新加载。
示例
以下是如何导出用于推断的模型。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")
# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)
以下是如何导出一个具有一个用于推断的终结点和一个用于训练模式前向传递(例如,启用 dropout)的终结点的模型。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model.call(x, training=False),
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model.call(x, training=True),
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")
资源跟踪注意事项
ExportArchive
能够自动跟踪其终结点使用的所有 keras.Variables
,因此大多数情况下并不严格要求调用 .track(model)
。但是,如果您的模型使用查找层,例如 IntegerLookup
、StringLookup
或 TextVectorization
,则需要通过 .track(model)
显式跟踪。
如果您需要能够访问恢复后的存档上的 variables
、trainable_variables
或 non_trainable_variables
属性,也需要显式跟踪。
add_endpoint
方法ExportArchive.add_endpoint(name, fn, input_signature=None, **kwargs)
注册一个新的服务终结点。
参数
str
。终结点的名称。ExportArchive
跟踪的模型/层上可用的资源(例如 keras.Variable
对象或 tf.lookup.StaticHashTable
对象)(您可以调用 .track(model)
来跟踪新模型)。函数输入的形状和 dtype 必须已知。为此,您可以 1) 确保 fn
是一个至少被调用过一次的 tf.function
,或 2) 提供一个指定输入形状和 dtype 的 input_signature
参数(见下文)。fn
的形状和 dtype。可以是 keras.InputSpec
、tf.TensorSpec
、backend.KerasTensor
或后端张量的结构(参见下文显示一个具有 2 个输入参数的 Functional
模型的示例)。如果未提供,fn
必须是一个至少被调用过一次的 tf.function
。默认为 None
。is_static
:可选 bool
。指示 fn
是否为静态。如果 fn
涉及状态更新(例如,RNG 种子),则设置为 False
。 - jax2tf_kwargs
:可选 dict
。jax2tf.convert
的参数。请参阅 jax2tf.convert
。如果未提供 native_serialization
和 polymorphic_shapes
,它们将自动计算。返回
已添加到存档中的 tf.function
包装器。
示例
当模型具有单个输入参数时,使用 input_signature
参数添加终结点
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
当模型具有两个位置输入参数时,使用 input_signature
参数添加终结点
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
keras.InputSpec(shape=(None, 3), dtype="float32"),
keras.InputSpec(shape=(None, 4), dtype="float32"),
],
)
当模型具有一个输入参数(该参数是一个包含 2 个张量的列表,例如具有 2 个输入的函数模型)时,使用 input_signature
参数添加终结点
model = keras.Model(inputs=[x1, x2], outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
[
keras.InputSpec(shape=(None, 3), dtype="float32"),
keras.InputSpec(shape=(None, 4), dtype="float32"),
],
],
)
这也适用于字典输入
model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
{
"x1": keras.InputSpec(shape=(None, 3), dtype="float32"),
"x2": keras.InputSpec(shape=(None, 4), dtype="float32"),
},
],
)
添加一个作为 tf.function
的终结点
@tf.function()
def serving_fn(x):
return model(x)
# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)
将模型与一些 TensorFlow 预处理相结合,其中可以使用 TensorFlow 资源
lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0)
export_archive = ExportArchive()
model_fn = export_archive.track_and_add_endpoint(
"model_fn",
model,
input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)],
)
export_archive.track(lookup_table)
@tf.function()
def serving_fn(x):
x = lookup_table.lookup(x)
return model_fn(x)
export_archive.add_endpoint(name="serve", fn=serving_fn)
add_variable_collection
方法ExportArchive.add_variable_collection(name, variables)
注册一组变量,以便在重新加载后检索。
参数
keras.Variable
实例的元组/列表/集合。示例
export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
# Save a variable collection
export_archive.add_variable_collection(
name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")
# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables
track
方法ExportArchive.track(resource)
跟踪(层或模型的)变量和其他资产。
默认情况下,当您调用 add_endpoint()
时,终结点函数使用的所有变量都会自动跟踪。但是,非变量资产(如查找表)需要手动跟踪。请注意,内置 Keras 层(TextVectorization
、IntegerLookup
、StringLookup
)使用的查找表由 add_endpoint()
自动跟踪。
参数
write_out
方法ExportArchive.write_out(filepath, options=None, verbose=True)
将相应的 SavedModel 写入磁盘。
参数
str
或 pathlib.Path
对象。保存工件的路径。tf.saved_model.SaveOptions
对象,指定 SavedModel 保存选项。关于 TF-Serving 的说明:所有通过 add_endpoint()
注册的终结点都在 SavedModel 工件中对 TF-Serving 可见。此外,第一个注册的终结点在别名 "serving_default"
下可见(除非已手动注册名称为 "serving_default"
的终结点),因为 TF-Serving 要求设置此终结点。