开发者指南 / 从头开始在 JAX 中编写训练循环

从头开始在 JAX 中编写训练循环

作者: fchollet
创建日期 2023/06/25
最后修改日期 2023/06/25
描述: 在 JAX 中编写低级别训练和评估循环。

在 Colab 中查看 GitHub 源码


设置

import os

# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax

# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np

引言

Keras 提供了默认的训练和评估循环,即 fit()evaluate()。其用法在指南使用内置方法进行训练和评估中有所介绍。

如果您想在仍然利用 fit() 便利性的同时(例如,使用 fit() 训练 GAN)定制模型的学习算法,您可以子类化 Model 类并实现自己的 train_step() 方法,该方法会在 fit() 期间被重复调用。

现在,如果您想对训练和评估进行非常低级别的控制,您应该从头开始编写自己的训练和评估循环。本指南正是关于这一主题的。


第一个端到端示例

要编写自定义训练循环,我们需要以下要素

  • 当然,还有一个要训练的模型。
  • 一个优化器。您可以使用 keras.optimizers 中的优化器,或者来自 optax 包的优化器。
  • 一个损失函数。
  • 一个数据集。JAX 生态系统中的标准做法是通过 tf.data 加载数据,因此我们将使用它。

让我们把它们准备好。

首先,让我们获取模型和 MNIST 数据集

def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

接下来,这是损失函数和优化器。在本例中,我们将使用 Keras 优化器。

# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

在 JAX 中获取梯度

让我们使用自定义训练循环和 mini-batch 梯度来训练我们的模型。

在 JAX 中,梯度是通过元编程计算的:您在函数上调用 jax.grad(或 jax.value_and_grad)以便为该函数创建梯度计算函数。

因此,我们首先需要一个返回损失值的函数。这就是我们将用来生成梯度函数的函数。类似这样

def compute_loss(x, y):
    ...
    return loss

一旦有了这样的函数,您就可以通过元编程计算梯度,如下所示

grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)

通常,您不仅想获取梯度值,还想获取损失值。您可以通过使用 jax.value_and_grad 代替 jax.grad 来实现这一点。

grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)

JAX 计算是纯粹无状态的

在 JAX 中,一切都必须是无状态函数——因此我们的损失计算函数也必须是无状态的。这意味着所有 Keras 变量(例如权重张量)必须作为函数输入传递,并且在前向传播期间更新的任何变量都必须作为函数输出返回。函数不能有副作用。

在前向传播期间,Keras 模型的非可训练变量可能会被更新。这些变量可能是,例如,RNG 种子状态变量或 BatchNormalization 统计信息。我们需要返回这些变量。因此我们需要类似这样的东西

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables

一旦有了这样的函数,您可以通过在 value_and_grad 中指定 has_aux 来获取梯度函数:它告诉 JAX 损失计算函数返回的输出不止损失值。请注意,损失值应始终是第一个输出。

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)

现在我们已经确立了基础,接下来实现 compute_loss_and_updates 函数。Keras 模型有一个 stateless_call 方法,在这里会非常有用。它与 model.__call__ 的工作方式类似,但它要求您显式传递模型中所有变量的值,并且它不仅返回 __call__ 的输出,还返回(可能已更新的)非可训练变量。

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables

让我们获取梯度函数

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

训练步骤函数

接下来,实现端到端的训练步骤,该函数将执行前向传播、计算损失、计算梯度,还会使用优化器更新可训练变量。这个函数也需要是无状态的,因此它将接收一个包含我们将使用的所有状态元素的 state 元组作为输入

  • trainable_variablesnon_trainable_variables:模型的变量。
  • optimizer_variables:优化器的状态变量,例如动量累加器。

为了更新可训练变量,我们使用优化器的无状态方法 stateless_apply。它等效于 optimizer.apply(),但它总是要求传递 trainable_variablesoptimizer_variables。它返回更新后的可训练变量和更新后的 optimizer_variables。

def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

使用 jax.jit 加速

默认情况下,JAX 操作是即时运行的(eagerly),就像 TensorFlow 和 PyTorch 的即时模式一样。而且就像 TensorFlow 和 PyTorch 的即时模式一样,它相当慢——即时模式更适合用作调试环境,而不是用于实际工作。因此,我们通过编译来加速 train_step

当您有一个无状态的 JAX 函数时,您可以通过 @jax.jit 装饰器将其编译为 XLA。它会在第一次执行时被追踪,后续执行将运行追踪到的图(这类似于 @tf.function(jit_compile=True))。让我们试试看

@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

现在我们准备好训练模型了。训练循环本身很简单:我们只需重复调用 loss, state = train_step(state, data)

注意

  • 在将 tf.data.Dataset 生成的 TF 张量传递给 JAX 函数之前,我们将其转换为 NumPy。
  • 所有变量都必须提前构建:模型必须构建,优化器也必须构建。由于我们使用的是函数式 API 模型,它已经构建好了,但如果是一个子类模型,您需要在一批数据上调用它来构建它。
# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples

这里需要注意的关键一点是,循环完全是无状态的——附加到模型上的变量(model.weights)在循环期间从未更新。它们的新值只存储在 state 元组中。这意味着在某个时候,在保存模型之前,您应该将新的变量值重新关联到模型上。

只需对您想要更新的每个模型变量调用 variable.assign(new_value)

trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

指标的低级别处理

让我们为这个基本的训练循环添加指标监控。

您可以在从头开始编写的训练循环中轻松重用内置的 Keras 指标(或您自己编写的自定义指标)。流程如下

  • 在循环开始时实例化指标
  • metric_variables 包含在 train_step 参数和 compute_loss_and_updates 参数中。
  • compute_loss_and_updates 函数中调用 metric.stateless_update_state()。它等效于 update_state() – 只是无状态版本。
  • 当您需要在 train_step 之外(在即时执行范围内)显示指标的当前值时,将新的指标变量值附加到指标对象上并调用 metric.result()
  • 当您需要清除指标状态时(通常在每个 epoch 结束时)调用 metric.reset_state()

让我们利用这些知识在训练结束时计算训练数据和验证数据上的 CategoricalAccuracy

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )

我们还将准备一个评估步骤函数

@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )

这是我们的循环

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables

# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples

模型跟踪的损失的低级别处理

层和模型递归地跟踪在前向传播期间由调用 self.add_loss(value) 的层创建的任何损失。生成的一系列标量损失值在前向传播结束时通过属性 model.losses 可用。

如果您想使用这些损失组成部分,您应该将它们求和并添加到训练步骤中的主损失中。

考虑这个创建活动正则化损失的层

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs

让我们构建一个使用它的非常简单的模型

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

现在我们的 compute_loss_and_updates 函数应该看起来像这样了

  • return_losses=True 传递给 model.stateless_call()
  • 对生成的 losses 求和并将其添加到主损失中。
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, non_trainable_variables, metric_variables

就这样!