Keras 3 API 文档 / 模型 API / 保存与序列化 / 仅保存和加载权重

仅保存和加载权重

[源代码]

save_weights 方法

Model.save_weights(filepath, overwrite=True, max_shard_size=None)

将所有权重保存到单个文件或分片文件中。

默认情况下,权重将保存为单个 .weights.h5 文件。如果启用分片(max_shard_size 不为 None),权重将保存为多个文件,每个文件的大小最大为 max_shard_size(以 GB 为单位)。此外,一个名为 .weights.json 的配置文件将包含分片文件的元数据。

保存的分片文件包含:

  • *.weights.json:包含“metadata”和“weight_map”的配置文件。
  • *_xxxxxx.weights.h5:仅包含权重的分片文件。

参数

  • filepathstrpathlib.Path 对象。权重将保存到的路径。分片时,filepath 必须以 .weights.json 结尾。如果提供了 .weights.h5,它将被覆盖。
  • overwrite:是否覆盖目标位置的任何现有权重,或者通过交互式提示询问用户。
  • max_shard_sizeintfloat。每个分片文件的最大大小(以 GB 为单位)。如果为 None,则不进行分片。默认为 None

示例

# Instantiate a EfficientNetV2L model with about 454MB of weights.
model = keras.applications.EfficientNetV2L(weights=None)

# Save the weights in a single file.
model.save_weights("model.weights.h5")

# Save the weights in sharded files. Use `max_shard_size=0.25` means
# each sharded file will be at most ~250MB.
model.save_weights("model.weights.json", max_shard_size=0.25)

# Load the weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.h5")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

# Load the sharded weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.json")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

[源代码]

load_weights 方法

Model.load_weights(filepath, skip_mismatch=False, **kwargs)

从单个文件或分片文件中加载权重。

权重根据网络的拓扑结构加载。这意味着架构应与保存权重时相同。请注意,没有权重的层在拓扑排序中不被考虑,因此只要它们没有权重,添加或移除层就没有问题。

部分权重加载

如果您修改了模型,例如添加了一个新层(带有权重)或更改了层权重的形状,您可以选择忽略错误并继续加载,方法是设置 skip_mismatch=True。在这种情况下,任何权重不匹配的层都将被跳过。每个跳过的层都将显示警告。

分片

加载分片权重时,指定以 *.weights.json 结尾的 filepath 非常重要,因为它用作配置文件。此外,分片文件 *_xxxxx.weights.h5 必须与配置文件位于同一目录中。

参数

  • filepathstrpathlib.Path 对象。权重将保存到的路径。分片时,filepath 必须以 .weights.json 结尾。
  • skip_mismatch:布尔值,当权重数量不匹配或权重形状不匹配时,是否跳过加载这些层。

示例

# Load the weights in a single file.
model.load_weights("model.weights.h5")

# Load the weights in sharded files.
model.load_weights("model.weights.json")