GroupedQueryAttention
类keras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
分组查询注意力层。
这是 Ainslie 等人,2023 引入的分组查询注意力的实现。这里 num_key_value_heads
表示分组数,将 num_key_value_heads
设置为 1 等同于多查询注意力,当 num_key_value_heads
等于 num_query_heads
时,等同于多头注意力。
此层首先投影 query
、key
和 value
张量。然后,重复 key
和 value
以匹配 query
的头数。
然后,对 query
进行缩放并与 key
张量进行点积。这些结果经过 softmax 运算以获得注意力概率。然后,使用这些概率对 value 张量进行插值,并将其拼接回单个张量。
参数
None
,则在可能的情况下,该层会尝试使用 Flash Attention 进行更快、更省内存的注意力计算。此行为可以通过 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
进行配置。调用参数
(batch_dim, target_seq_len, feature_dim)
,其中 batch_dim
是批大小,target_seq_len
是目标序列的长度,feature_dim
是特征的维度。(batch_dim, source_seq_len, feature_dim)
,其中 batch_dim
是批大小,source_seq_len
是源序列的长度,feature_dim
是特征的维度。(batch_dim, source_seq_len, feature_dim)
。如果未给出,将同时使用 value
作为 key
和 value
,这是最常见的情况。(batch_dim, target_seq_len, source_seq_len)
的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定了哪些查询元素可以关注哪些键元素,其中 1 表示关注,0 表示不关注。对于缺失的批维度和头维度,可以进行广播。True
,输出应为 (attention_output, attention_scores)
,如果为 False
,则为 attention_output
。默认为 False
。False
(推理)。返回
(batch_dim, target_seq_len, feature_dim)
,其中 target_seq_len
是目标序列长度,feature_dim
是查询输入的最后一个维度。(batch_dim, num_query_heads, target_seq_len, source_seq_len)
。