Keras 3 API 文档 / 再计算 / RematScope

RematScope

[源代码]

RematScope

keras.RematScope(mode="full", output_size_threshold=1024, layer_names=None)

Keras 中用于启用再计算的上下文管理器。

再计算(梯度检查点)通过在反向传播期间重新计算中间激活来权衡内存与计算。这对于在有限的内存约束下训练大型模型或大型批次尤其有用。

这应该在初始化层时使用(例如,layer(input))。再计算在执行时应用,而不是在创建时应用。

参数

  • mode:要应用的再计算模式。选项:
    • "full":将再计算全局应用于所有支持的操作。
    • "activations":对任何包含 keras.activations 的层(例如,Dense(..., activation=relu))上的激活应用再计算。
    • "larger_than":对输出大小大于 output_size_threshold 的层应用再计算。
    • "list_of_layers":对特定层名称列表应用再计算。
    • None:禁用再计算。
  • output_size_threshold"larger_than" 模式的输出大小阈值。生成输出大于此阈值的层将被再计算。默认值为 1024
  • layer_names"list_of_layers" 模式的层名称列表。默认值为空列表。

示例

使用 "list_of_layers" 模式

from keras import RematScope
input_tensor = tf.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers", layer_names=["dense_1",
"conv2d_1"]):
    layer1 = keras.layers.Dense(128, name="dense_1")
    layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
    layer3 = keras.layers.Dense(64, name="dense_2")
    # Only layer1 and layer2 will apply rematerialization
    output1 = layer1(input_tensor)
    output2 = layer2(output1)
    output3 = layer3(output2)

使用 "larger_than" 模式并指定输出大小阈值

with RematScope(mode="larger_than", output_size_threshold=2048):
    layer = keras.layers.Conv2D(64, (3, 3))
    output = layer(input_tensor)  # Conv2D outputs larger than 2048

用于精细控制的嵌套范围

with RematScope(mode="full"):
    # Create layers
    layer1 = keras.layers.Dense(128, activation='relu')
    output1 = layer1(input_tensor)  # layer1 is fully rematerialized
    with RematScope(mode="larger_than", output_size_threshold=512):
        layer2 = keras.layers.Conv2D(32, (3, 3))
        output2 = layer2(output1) # layer2 is conditionally rematerialized
        # if output > 512