RematScope
类keras.RematScope(mode="full", output_size_threshold=1024, layer_names=None)
Keras 中用于启用再计算的上下文管理器。
再计算(梯度检查点)通过在反向传播期间重新计算中间激活来权衡内存与计算。这对于在有限的内存约束下训练大型模型或大型批次尤其有用。
这应该在初始化层时使用(例如,layer(input)
)。再计算在执行时应用,而不是在创建时应用。
参数
"full"
:将再计算全局应用于所有支持的操作。"activations"
:对任何包含 keras.activations
的层(例如,Dense(..., activation=relu)
)上的激活应用再计算。"larger_than"
:对输出大小大于 output_size_threshold
的层应用再计算。"list_of_layers"
:对特定层名称列表应用再计算。None
:禁用再计算。"larger_than"
模式的输出大小阈值。生成输出大于此阈值的层将被再计算。默认值为 1024
。"list_of_layers"
模式的层名称列表。默认值为空列表。示例
使用 "list_of_layers" 模式
from keras import RematScope
input_tensor = tf.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers", layer_names=["dense_1",
"conv2d_1"]):
layer1 = keras.layers.Dense(128, name="dense_1")
layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
layer3 = keras.layers.Dense(64, name="dense_2")
# Only layer1 and layer2 will apply rematerialization
output1 = layer1(input_tensor)
output2 = layer2(output1)
output3 = layer3(output2)
使用 "larger_than" 模式并指定输出大小阈值
with RematScope(mode="larger_than", output_size_threshold=2048):
layer = keras.layers.Conv2D(64, (3, 3))
output = layer(input_tensor) # Conv2D outputs larger than 2048
用于精细控制的嵌套范围
with RematScope(mode="full"):
# Create layers
layer1 = keras.layers.Dense(128, activation='relu')
output1 = layer1(input_tensor) # layer1 is fully rematerialized
with RematScope(mode="larger_than", output_size_threshold=512):
layer2 = keras.layers.Conv2D(32, (3, 3))
output2 = layer2(output1) # layer2 is conditionally rematerialized
# if output > 512