Keras 3 API 文档 / 层 API / 注意力层 / GroupQueryAttention

分组查询注意力 (GroupQueryAttention)

[源代码]

GroupedQueryAttention

keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

分组查询注意力层。

这是 Ainslie 等人,2023 引入的分组查询注意力的实现。这里 num_key_value_heads 表示分组数,将 num_key_value_heads 设置为 1 等同于多查询注意力,当 num_key_value_heads 等于 num_query_heads 时,等同于多头注意力。

此层首先投影 querykeyvalue 张量。然后,重复 keyvalue 以匹配 query 的头数。

然后,对 query 进行缩放并与 key 张量进行点积。这些结果经过 softmax 运算以获得注意力概率。然后,使用这些概率对 value 张量进行插值,并将其拼接回单个张量。

参数

  • head_dim:每个注意力头的维度。
  • num_query_heads:查询注意力头的数量。
  • num_key_value_heads:键和值注意力头的数量。
  • dropout:Dropout 概率。
  • use_bias:布尔值,表示密集层是否使用偏置向量/矩阵。
  • flash_attention:如果为 None,则在可能的情况下,该层会尝试使用 Flash Attention 进行更快、更省内存的注意力计算。此行为可以通过 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 进行配置。
  • kernel_initializer:密集层核的初始化器。
  • bias_initializer:密集层偏置的初始化器。
  • kernel_regularizer:密集层核的正则化器。
  • bias_regularizer:密集层偏置的正则化器。
  • activity_regularizer:密集层活动的正则化器。
  • kernel_constraint:密集层核的约束。
  • bias_constraint:密集层核的约束。
  • seed:用于 Dropout 层的可选整数种子。

调用参数

  • query:查询张量,形状为 (batch_dim, target_seq_len, feature_dim),其中 batch_dim 是批大小,target_seq_len 是目标序列的长度,feature_dim 是特征的维度。
  • value:值张量,形状为 (batch_dim, source_seq_len, feature_dim),其中 batch_dim 是批大小,source_seq_len 是源序列的长度,feature_dim 是特征的维度。
  • key:可选键张量,形状为 (batch_dim, source_seq_len, feature_dim)。如果未给出,将同时使用 value 作为 keyvalue,这是最常见的情况。
  • attention_mask:形状为 (batch_dim, target_seq_len, source_seq_len) 的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定了哪些查询元素可以关注哪些键元素,其中 1 表示关注,0 表示不关注。对于缺失的批维度和头维度,可以进行广播。
  • return_attention_scores:一个布尔值,指示如果为 True,输出应为 (attention_output, attention_scores),如果为 False,则为 attention_output。默认为 False
  • training:Python 布尔值,指示该层是应在训练模式下运行(添加 Dropout)还是在推理模式下运行(无 Dropout)。如果存在父层/模型,则将使用其训练模式;如果不存在父层,则为 False(推理)。
  • use_causal_mask:一个布尔值,指示是否应用因果掩码以防止 token 关注未来的 token(例如,在解码器 Transformer 中使用)。

返回

  • attention_output:计算结果,形状为 (batch_dim, target_seq_len, feature_dim),其中 target_seq_len 是目标序列长度,feature_dim 是查询输入的最后一个维度。
  • attention_scores:(可选) 注意力系数,形状为 (batch_dim, num_query_heads, target_seq_len, source_seq_len)