Keras 2 API 文档 / 回调 API / 基础回调类

基础回调类

[源代码]

Callback

tf_keras.callbacks.Callback()

用于构建新回调的抽象基类。

可以将回调传递给诸如 fitevaluatepredict 等 Keras 方法,以便钩入模型训练和推理生命周期的各个阶段。

要创建自定义回调,请继承 keras.callbacks.Callback 并覆盖与感兴趣阶段关联的方法。有关更多信息,请参阅 自定义回调

示例

>>> training_finished = False
>>> class MyCallback(tf.keras.callbacks.Callback):
...   def on_train_end(self, logs=None):
...     global training_finished
...     training_finished = True
>>> model = tf.keras.Sequential([
...     tf.keras.layers.Dense(1, input_shape=(1,))])
>>> model.compile(loss='mean_squared_error')
>>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
...           callbacks=[MyCallback()])
>>> assert training_finished == True

如果您想在自定义训练循环中使用 Callback 对象

  1. 您应该将所有回调打包到单个 callbacks.CallbackList 中,以便可以一起调用它们。
  2. 您需要在循环中适当的位置手动调用所有 on_* 方法。像这样

示例

   callbacks =  tf.keras.callbacks.CallbackList([...])
   callbacks.append(...)
   callbacks.on_train_begin(...)
   for epoch in range(EPOCHS):
     callbacks.on_epoch_begin(epoch)
     for i, data in dataset.enumerate():
       callbacks.on_train_batch_begin(i)
       batch_logs = model.train_step(data)
       callbacks.on_train_batch_end(i, batch_logs)
     epoch_logs = ...
     callbacks.on_epoch_end(epoch, epoch_logs)
   final_logs=...
   callbacks.on_train_end(final_logs)

属性

  • params:字典。训练参数(例如,详细程度、批量大小、轮数等)。
  • modelkeras.models.Model 的实例。正在训练的模型的引用。

回调方法作为参数接收的 logs 字典将包含与当前批次或轮次相关的数量的键(请参阅特定于方法的文档字符串)。