Keras 2 API 文档 / 混合精度 / 混合精度策略 API

混合精度策略 API

[源代码]

Policy

tf_keras.mixed_precision.Policy(name)

TF-Keras 层的 dtype 策略。

dtype 策略确定层的计算和变量 dtype。每个层都有一个策略。策略可以传递给层构造函数的 dtype 参数,或者可以使用 tf.keras.mixed_precision.set_global_policy 设置全局策略。

参数

  • name:策略名称,它确定计算和变量 dtype。可以是任何 dtype 名称,例如 'float32''float64',这将导致计算和变量 dtype 都将是该 dtype。也可以是字符串 'mixed_float16''mixed_bfloat16',这将导致计算 dtype 为 float16 或 bfloat16,而变量 dtype 为 float32。

通常,仅在使用混合精度时才需要与 dtype 策略进行交互,混合精度是指使用 float16 或 bfloat16 进行计算,而使用 float32 作为变量。这就是术语 mixed_precision 出现在 API 名称中的原因。可以通过将 'mixed_float16''mixed_bfloat16' 传递给 tf.keras.mixed_precision.set_global_policy 来启用混合精度。有关如何使用混合精度的更多信息,请参阅混合精度指南

>>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
>>> layer1 = tf.keras.layers.Dense(10)
>>> layer1.dtype_policy  # `layer1` will automatically use mixed precision
<Policy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = tf.keras.layers.Dense(10, dtype='float32')
>>> layer2.dtype_policy
<Policy "float32">
>>> # Set policy back to initial float32 for future examples.
>>> tf.keras.mixed_precision.set_global_policy('float32')

在上面的示例中,将 dtype='float32' 传递给层等同于传递 dtype=tf.keras.mixed_precision.Policy('float32')。通常,将 dtype 策略名称传递给层等同于传递相应的策略,因此永远不需要显式构造 Policy 对象。

注意:如果您使用 'mixed_float16' 策略,Model.compile 将自动使用 tf.keras.mixed_precision.LossScaleOptimizer 包装优化器。如果您使用自定义训练循环而不是调用 Model.compile,则应显式使用 tf.keras.mixed_precision.LossScaleOptimizer,以避免 float16 出现数值下溢。

层如何使用其策略的计算 dtype

层将其输入转换为其计算 dtype。这会导致层的计算和输出也采用计算 dtype。例如

>>> x = tf.ones((4, 4, 4, 4), dtype='float64')
>>> # `layer`'s policy defaults to float32.
>>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
>>> layer.compute_dtype  # Equivalent to layer.dtype_policy.compute_dtype
'float32'
>>> # `layer` casts its inputs to its compute dtype and does computations in
>>> # that dtype.
>>> y = layer(x)
>>> y.dtype
tf.float32

请注意,基本 tf.keras.layers.Layer 类插入了转换。如果子类化您自己的层,则无需插入任何转换。

当前,仅将传递给层 call 方法的第一个参数中的张量进行转换(尽管这可能会在以后的次要版本中更改)。例如

>>> class MyLayer(tf.keras.layers.Layer):
...   # Bug! `b` will not be casted.
...   def call(self, a, b):
...     return a + 1., b + 1.
>>> a = tf.constant(1., dtype="float32")
>>> b = tf.constant(1., dtype="float32")
>>> layer = MyLayer(dtype="float64")
>>> x, y = layer(a, b)
>>> x.dtype
tf.float64
>>> y.dtype
tf.float32

如果使用多个输入编写自己的层,则应在 call 中显式地将其他张量转换为 self.compute_dtype,或者接受第一个参数中的所有张量作为列表。

转换仅在 TensorFlow 2 中发生。如果已调用 tf.compat.v1.disable_v2_behavior(),则可以使用 tf.compat.v1.keras.layers.enable_v2_dtype_behavior() 启用转换行为。

层如何使用其策略的变量 dtype

tf.keras.layers.Layer.add_weight 创建的变量的默认 dtype 是层的策略的变量 dtype。

如果层的计算 dtype 和变量 dtype 不同,add_weight 将使用名为 AutoCastVariable 的特殊包装器包装浮点变量。AutoCastVariable 与原始变量相同,只是它在 Layer.call 中使用时将其自身转换为层的计算 dtype。这意味着如果您正在编写层,则无需显式地将变量转换为层的计算 dtype。例如

>>> class SimpleDense(tf.keras.layers.Layer):
...
...   def build(self, input_shape):
...     # With mixed precision, self.kernel is a float32 AutoCastVariable
...     self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
...
...   def call(self, inputs):
...     # With mixed precision, self.kernel will be casted to float16
...     return tf.linalg.matmul(inputs, self.kernel)
...
>>> layer = SimpleDense(dtype='mixed_float16')
>>> y = layer(tf.ones((10, 10)))
>>> y.dtype
tf.float16
>>> layer.kernel.dtype
tf.float32

层作者可以通过将 experimental_autocast=False 传递给 add_weight 来防止变量被 AutoCastVariable 包装,如果必须在层内访问变量的 float32 值,则此功能很有用。

如何编写支持混合精度和 float64 的层。

在大多数情况下,由于基本层自动转换输入、创建正确类型的变量,并且在混合精度的情况下,使用 AutoCastVariables 包装变量,因此层将自动支持混合精度和 float64,而无需任何额外的工作。

您需要额外工作以支持混合精度或 float64 的主要情况是,当您创建新张量时,例如使用 tf.onestf.random.normal。在这种情况下,您必须创建具有正确 dtype 的张量。例如,如果您调用 tf.random.normal,则必须传递计算 dtype,这是输入已转换成的 dtype

>>> class AddRandom(tf.keras.layers.Layer):
...
...   def call(self, inputs):
...     # We must pass `dtype=inputs.dtype`, otherwise a TypeError may
...     # occur when adding `inputs` to `rand`.
...     rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
...     return inputs + rand
>>> layer = AddRandom(dtype='mixed_float16')
>>> y = layer(x)
>>> y.dtype
tf.float16

如果您没有将 dtype=inputs.dtype 传递给 tf.random.normal,则会发生 TypeError。这是因为 tf.random.normal 的 dtype 默认为 "float32",但输入 dtype 是 float16。您不能将 float32 张量与 float16 张量相加。


[源代码]

global_policy 函数

tf_keras.mixed_precision.global_policy()

返回全局 dtype 策略。

如果未将策略传递给层构造函数,则全局策略是用于层的默认 tf.keras.mixed_precision.Policy。如果尚未使用 keras.mixed_precision.set_global_policy 设置策略,则此函数将返回从 tf.keras.backend.floatx() 构造的策略(floatx 默认为 float32)。

>>> tf.keras.mixed_precision.global_policy()
<Policy "float32">
>>> tf.keras.layers.Dense(10).dtype_policy  # Defaults to the global policy
<Policy "float32">

如果已使用 tf.compat.v1.disable_v2_behavior() 禁用了 TensorFlow 2 行为,则此函数将返回一个特殊的“_infer”策略,该策略会从第一次调用层时第一个输入的 dtype 推断 dtype。此行为与 TensorFlow 1 中存在的行为相匹配。

有关策略的更多信息,请参见 tf.keras.mixed_precision.Policy

返回

全局策略。


[源代码]

set_global_policy 函数

tf_keras.mixed_precision.set_global_policy(policy)

设置全局 dtype 策略。

如果未将策略传递给层构造函数,则全局策略是用于层的默认 tf.keras.mixed_precision.Policy

>>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
>>> tf.keras.mixed_precision.global_policy()
<Policy "mixed_float16">
>>> tf.keras.layers.Dense(10).dtype_policy
<Policy "mixed_float16">
>>> # Global policy is not used if a policy
>>> # is directly passed to constructor
>>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
<Policy "float64">
>>> tf.keras.mixed_precision.set_global_policy('float32')

如果未设置全局策略,则层将默认为从 tf.keras.backend.floatx() 构造的策略。

要使用混合精度,应将全局策略设置为 'mixed_float16''mixed_bfloat16',以便每个层默认使用 16 位计算 dtype 和 float32 变量 dtype。

只有浮点策略可以设置为全局策略,例如 'float32''mixed_float16'。非浮点策略(例如 'int32''complex64')不能设置为全局策略,因为大多数层不支持此类策略。

有关更多信息,请参见 tf.keras.mixed_precision.Policy

参数

  • policy:一个 Policy,或一个将转换为 Policy 的字符串。也可以是 None,在这种情况下,全局策略将从 tf.keras.backend.floatx() 构造