Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in JAX.
import os
# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"
import jax
# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np
Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.

If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().

Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.
Writing a custom training loop requires the following ingredients:

- A model to train, of course.
- An optimizer. You could either use an optimizer from keras.optimizers, or one from the optax package.
- A loss function.
- A dataset. The standard in the JAX ecosystem is to load data via tf.data, so that's what we'll use.

Let's get them ready.
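(For reference, if you preferred optax, the basic update flow would look roughly like the sketch below. The parameter and gradient values here are toy placeholders of our own, and the rest of this guide sticks with the Keras optimizer.)

import jax.numpy as jnp
import optax

# Toy parameters and gradients, just to show the shape of the optax API.
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
grads = {"w": jnp.ones((3,)), "b": jnp.ones(())}

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
# Given gradients, produce parameter updates and a new optimizer state.
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)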
First, let's get the model and the MNIST dataset:
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
model = get_model()
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
Next, here's the loss function and the optimizer. We'll use a Keras optimizer in this case.
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
Let's train our model using mini-batch gradient with a custom training loop.

In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for that first function.

So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. Something like this:
def compute_loss(x, y):
    ...
    return loss
Once you have such a function, you can compute gradients via metaprogramming as such:
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
Typically, you don't just want to get the gradient values, you also want to get the loss value. You can do this by using jax.value_and_grad instead of jax.grad:
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
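As a quick standalone illustration (the toy function below is our own and is not part of the guide's model), here is what jax.value_and_grad gives you for a simple scalar function:

import jax
import jax.numpy as jnp

def square_sum(x):
    # f(x) = sum(x ** 2), so df/dx = 2 * x.
    return jnp.sum(x**2)

value_and_grad_fn = jax.value_and_grad(square_sum)
value, grad = value_and_grad_fn(jnp.array([1.0, 2.0, 3.0]))
# value -> 14.0, grad -> [2., 4., 6.]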
In JAX, everything must be a stateless function, so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that has been updated during the forward pass must be returned as a function output. The function can have no side effects.

During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables
Once you have such a function, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
Now that we have established the basics, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method which will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables
Let's get the gradient function:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
Next, let's implement the end-to-end training step: the function that will run the forward pass, compute the loss, compute the gradients, and also use the optimizer to update the trainable variables. This function also needs to be stateless, so it will get as input a state tuple that includes every state element we're going to use:

- trainable_variables and non_trainable_variables: the model's variables.
- optimizer_variables: the optimizer's state variables, such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it always requires passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer_variables.
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
Make it fast with jax.jit

By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow: eager mode is better used as a debugging environment, not as a way to do any actual work. So let's make our train_step fast by compiling it.

When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like @tf.function(jit_compile=True)). Let's try it:
@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
Now we're ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).

Note:

- We convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.
- All variables must be built beforehand: the model must be built and the optimizer must be built. Since we're using a Functional API model, it's already built, but if it were a subclassed model you'd need to call it on a batch of data to build it.

# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples
A key thing to notice here is that the loop is entirely stateless: the variables attached to the model (model.weights) are never updated during the loop. Their new values are only stored in the state tuple. That means that at some point, before saving your model, you should reattach the new variable values to the model.

Just call variable.assign(new_value) on each model variable you want to update:
trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
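Once the values have been reattached, you can save or evaluate the model as usual. For instance (the filename here is just an illustration):

model.save("model_trained_with_custom_loop.keras")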
Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop.
- Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
- Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state(), only stateless.
- When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
- Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch).

Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )
We'll also prepare an evaluation step function:
@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )
Here are our loops:
# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables
# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples
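As noted in the metric flow above, if you were training for multiple epochs you would typically clear the metric state between epochs. A minimal sketch of what that could look like in this setup (our own addition, not part of the loop above):

# At the end of an epoch: reset the metric objects and refresh the
# metric variable values that get threaded through the state tuple.
train_acc_metric.reset_state()
val_acc_metric.reset_state()
metric_variables = train_acc_metric.variables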
Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs
Let's build a really simple model that uses it:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Now our compute_loss_and_updates function should look like this:

- Pass return_losses=True to model.stateless_call().
- Sum the resulting losses and add them to the main loss.

def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
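To plug this into the training step, you would regenerate the gradient function exactly the same way as before:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)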
That's it!