Keras 2 API 文档 / 层 API / 注意力层 / 多头注意力层

多头注意力层

[源代码]

MultiHeadAttention

tf_keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
)

多头注意力层。

这是多头注意力机制的实现,如论文“Attention is all you Need”(Vaswani 等人,2017)中所述。如果querykeyvalue相同,则为自注意力。query中的每个时间步都关注key中对应的序列,并返回一个固定宽度的向量。

此层首先投影querykeyvalue。这些(实际上)是一个长度为num_attention_heads的张量列表,其中对应的形状为(batch_size, <query dimensions>, key_dim)(batch_size, <key/value dimensions>, key_dim)(batch_size, <key/value dimensions>, value_dim)

然后,将查询张量和键张量进行点积并进行缩放。对这些结果进行softmax运算以获得注意力概率。然后,通过这些概率对值张量进行插值,然后连接回单个张量。

最后,最后一个维度为value_dim的结果张量可以进行线性投影并返回。

在自定义层中使用MultiHeadAttention时,自定义层必须实现自己的build()方法并在其中调用MultiHeadAttention_build_from_signature()。这使得在加载模型时可以正确恢复权重。

示例

对两个序列输入执行 1D 交叉注意力,并使用注意力掩码。返回各个头上的额外注意力权重。

>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
...                                return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)

对 5D 输入张量在轴 2 和 3 上执行 2D 自注意力。

>>> layer = MultiHeadAttention(
...     num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)

参数

  • num_heads:注意力头的数量。
  • key_dim:查询和键的每个注意力头的尺寸。
  • value_dim:值的每个注意力头的尺寸。
  • dropout:丢弃概率。
  • use_bias:布尔值,指示密集层是否使用偏差向量/矩阵。
  • output_shape:输出张量的预期形状,除了批次和序列维度。如果未指定,则投影回查询特征维度(查询输入的最后一个维度)。
  • attention_axes:应用注意力的轴。None 表示对所有轴应用注意力,但批次、头和特征除外。
  • kernel_initializer:密集层内核的初始化器。
  • bias_initializer:密集层偏差的初始化器。
  • kernel_regularizer:密集层内核的正则化器。
  • bias_regularizer:密集层偏差的正则化器。
  • activity_regularizer:密集层活动的正则化器。
  • kernel_constraint:密集层内核的约束。
  • bias_constraint:密集层内核的约束。

调用参数

  • query:形状为(B, T, dim)的查询Tensor
  • value:形状为(B, S, dim)的值Tensor
  • key:形状为(B, S, dim)的可选键Tensor。如果未给出,则将同时使用value作为keyvalue,这是最常见的情况。
  • attention_mask:形状为(B, T, S)的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定哪些查询元素可以关注哪些键元素,1 表示注意力,0 表示没有注意力。对于缺失的批次维度和头维度,可以发生广播。
  • return_attention_scores:一个布尔值,指示输出是否应为(attention_output, attention_scores)(如果为True),或attention_output(如果为False)。默认为False
  • training:Python 布尔值,指示层是否应处于训练模式(添加丢弃)或推理模式(不添加丢弃)。将使用父层/模型的训练模式,或者如果没有父层则使用 False(推理)。
  • use_causal_mask:一个布尔值,指示是否应用因果掩码以防止标记关注未来的标记(例如,在解码器 Transformer 中使用)。

返回值

  • attention_output:计算结果,形状为(B, T, E),其中T表示目标序列形状,E是如果output_shapeNone则查询输入的最后一个维度。否则,多头输出将投影到output_shape指定的形状。
  • attention_scores:[可选]注意力轴上的多头注意力系数。