ModelCheckpoint
类tf_keras.callbacks.ModelCheckpoint(
filepath,
monitor: str = "val_loss",
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = "auto",
save_freq="epoch",
options=None,
initial_value_threshold=None,
**kwargs
)
用于以一定频率保存 TF-Keras 模型或模型权重的回调函数。
ModelCheckpoint
回调函数与使用 model.fit()
进行训练结合使用,以便在某个间隔保存模型或权重(在检查点文件中),以便以后可以加载模型或权重以从保存的状态继续训练。
此回调函数提供了一些选项,包括:
注意:如果您收到 WARNING:tensorflow:Can save best model only with <name> available, skipping
的警告,请参阅 monitor
参数的说明,了解如何正确设置。
示例
model.compile(loss=..., optimizer=...,
metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) are loaded into the
# model.
model.load_weights(checkpoint_filepath)
参数
PathLike
,保存模型文件的路径。例如,filepath = os.path.join(working_dir, 'ckpt', file_name)。filepath
可以包含命名格式选项,这些选项将填充 epoch
的值和 logs
中的键(在 on_epoch_end
中传递)。例如:如果 filepath
为 weights.{epoch:02d}-{val_loss:.2f}.hdf5
,则模型检查点将使用 epoch 编号和文件名中的验证损失保存。filepath 的目录不应被任何其他回调函数重复使用,以避免冲突。monitor:要监控的指标名称。通常,指标由 Model.compile
方法设置。注意:
"val_"
以监控验证指标。"loss"
或 "val_loss
" 以监控模型的总损失。"accuracy"
,请传递相同的字符串(带或不带 "val_"
前缀)。metrics.Metric
对象,则 monitor
应设置为 metric.name
history = model.fit()
返回的 history.history
字典的内容。save_best_only=True
,则仅在模型被认为是“最佳”模型时才保存,并且根据监控的数量,最新的最佳模型将不会被覆盖。如果 filepath
不包含诸如 {epoch}
之类的格式选项,则每个新的更好模型都会覆盖 filepath
。save_best_only=True
,则覆盖当前保存文件的决定是根据监控数量的最大化或最小化做出的。对于 val_acc
,这应该是 max
,对于 val_loss
,这应该是 min
,等等。在 auto
模式下,如果监控的量是 'acc' 或以 'fmeasure' 开头,则模式设置为 max
,对于其余量则设置为 min
。model.save_weights(filepath)
),否则保存完整模型 (model.save(filepath)
)。'epoch'
或整数。使用 'epoch'
时,回调函数在每个 epoch 之后保存模型。使用整数时,回调函数在这么多批次结束时保存模型。如果 Model
使用 steps_per_execution=N
编译,则保存条件将每 N 个批次检查一次。请注意,如果保存与 epoch 不对齐,则监控的指标可能会不太可靠(它可能反映的仅仅是 1 个批次,因为指标在每个 epoch 重置)。默认为 'epoch'
。save_weights_only
为真,则为可选的 tf.train.CheckpointOptions
对象;如果 save_weights_only
为假,则为可选的 tf.saved_model.SaveOptions
对象。save_best_value=True
时适用。仅当当前模型的性能优于此值时,才会覆盖已保存的模型权重。period
。