重新物化(Remat)

[源代码]

remat 函数

keras.remat(f)

对函数或层应用重新物化以进行内存优化。

重新物化是一种内存优化技术,它以计算换取内存。它不是存储中间结果(例如激活值)用于反向传播,而是在反向传播过程中重新计算它们。这以增加计算时间为代价,减少了峰值内存使用,从而允许在相同的内存限制下训练更大的模型或使用更大的批处理大小。

参数

  • f:可调用函数,对其应用重新物化。这通常是一个计算开销大的操作,其中可以重新计算而不是存储中间状态。

返回

一个包装函数,用于应用重新物化。返回的函数定义了一个自定义梯度,确保在反向传播过程中,根据需要重新计算前向计算。

示例

from keras import Model
class CustomRematLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.remat_function = remat(self.intermediate_function)

    def intermediate_function(self, x):
        for _ in range(2):
            x = x + x * 0.1  # Simple scaled transformation
        return x

    def call(self, inputs):
        return self.remat_function(inputs)

# Define a simple model using the custom layer
inputs = layers.Input(shape=(4,))
x = layers.Dense(4, activation="relu")(inputs)
x = CustomRematLayer()(x)  # Custom layer with rematerialization
outputs = layers.Dense(1)(x)

# Create and compile the model
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd", loss="mse")