MHA
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention (MHA) 的标准实现
"""
def __init__(self, d_model, num_heads):
"""
Args:
d_model (int): 模型的总维度。
num_heads (int): 注意力头的数量。
"""
super().__init__()
# 确保 d_model 可以被 num_heads 整除
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_model // num_heads # 每个头的维度
# 定义 Q, K, V 的线性投射层
# 这里我们将它们合并到一个大的线性层,也可以分开定义
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 最终的输出线性层
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
"""
前向传播
Args:
query (Tensor): shape (batch_size, seq_len_q, d_model)
key (Tensor): shape (batch_size, seq_len_k, d_model)
value (Tensor): shape (batch_size, seq_len_v, d_model)
mask (Tensor, optional): Defaults to None.
"""
batch_size = query.size(0)
# 1. 线性投射
Q = self.W_q(query) # (batch_size, seq_len_q, d_model)
K = self.W_k(key) # (batch_size, seq_len_k, d_model)
V = self.W_v(value) # (batch_size, seq_len_v, d_model)
# 2. 拆分成多个头
# view: 重塑 tensor
# transpose: 交换维度,将 num_heads 维度提前,方便并行计算
# (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_head) -> (batch, num_heads, seq_len, d_head)
Q = Q.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
# 3. 计算缩放点积注意力
# context 的 shape: (batch, num_heads, seq_len_q, d_head)
context, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# 4. 拼接所有头
# transpose 和 contiguous().view() 是上面拆分操作的逆过程
# (batch, num_heads, seq_len_q, d_head) -> (batch, seq_len_q, num_heads, d_head)
context = context.transpose(1, 2).contiguous()
# -> (batch, seq_len_q, d_model)
context = context.view(batch_size, -1, self.d_model)
# 5. 最终线性投射
output = self.W_o(context) # (batch, seq_len_q, d_model)
return output, attn_weights
恢复 output 的维度时,需要注意一点,transpose() 会返回一个 非连续张量(因为它只是换了 strides,不是真正复制内存)。.view() 要求张量必须是 内存连续的,否则就会报错。所以必须使用 ==.contiguous()== 来把张量转化为内存连续的
这个 scaled_dot_product_attention 的实现如下:
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""
计算缩放点积注意力的核心函数
"""
# 1. 计算 Q 和 K^T 的点积
attn_scores = torch.matmul(Q, K.transpose(-2, -1))
# 2. 缩放
scaled_attn_scores = attn_scores / math.sqrt(self.d_head)
# 3. (可选) 应用 mask
if mask is not None:
# mask 的值为 True 的地方会被设置为一个非常小的负数(-1e9 -> -inf -> softmax -> 0)
scaled_attn_scores = scaled_attn_scores.masked_fill(mask == 0, -1e9)
# 4. 计算 softmax 得到注意力权重
attn_weights = torch.softmax(scaled_attn_scores, dim=-1) # dim=-1 一定要写,没有默认值,会报错
# 5. 将权重应用于 V
output = torch.matmul(attn_weights, V)
return output, attn_weights
这里有一些 pytorch 的 API 需要注意。转置一个矩阵的方法有两种:使用 ==A.transpose(-1, -2)==, or ==A.mT== 这两个操作是等价的。A.mT 是 matrix transpose 的意思,会自动交换最后两个维度。注意,在高维矩阵上不能使用 .T 这个操作,这个操作会把所有维度都转置了,.T 只能用在二维的情况下。
还有一个 API 需要注意的:我们对 attention_sore 进行掩码时使用的 API 是 ==masked_fill(condition, val)==。用法就是给 mask==0 的元素,赋予一个趋近于负无穷的数字,这样经过 softmax 函数以后会输出 0 。写出调用方法: ==masked_fill(m == 0, -1e9)==。
GQA
GQA 的实现和 MHA 基本上没啥区别。核心的区别有两个:1. KV 的 projection 矩阵的输出维度更小。W_k = ==nn.Linear(dim, num_kv_heads * h_dim)==;2. 计算 scaled_dot_product_attention 之前,我们必须要对 KV 矩阵进行重复(使用 ==K.repeat_interleave()== API),将 KV 矩阵恢复为和 Q 一样的 shape 进行计算。
import torch
import torch.nn as nn
import math
class GroupedQueryAttention(nn.Module):
"""
Grouped-Query Attention (GQA) 的实现
"""
def __init__(self, d_model, num_q_heads, num_kv_heads):
"""
Args:
d_model (int): 模型的总维度。
num_q_heads (int): Query 头的数量。
num_kv_heads (int): Key/Value 头的数量 (GQA中的 'G')。
"""
super(GroupedQueryAttention, self).__init__()
# 确保维度和头的数量是可整除的
assert d_model % num_q_heads == 0, "d_model must be divisible by num_q_heads"
assert num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"
self.d_model = d_model
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.num_groups = num_q_heads // num_kv_heads # 每个 K/V 组对应的 Q 头的数量
self.d_head = d_model // num_q_heads
# 定义 Q, K, V 的线性投射层
# Q 的投射维度是完整的 d_model
self.W_q = nn.Linear(d_model, d_model)
# K 和 V 的投射维度更小,因为它们的头更少
self.W_k = nn.Linear(d_model, self.num_kv_heads * self.d_head)
self.W_v = nn.Linear(d_model, self.num_kv_heads * self.d_head)
# 最终的输出线性层
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1. 线性投射
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
# 2. 拆分 Q, K, V 头
# Q: (batch, seq_len_q, d_model) -> (batch, num_q_heads, seq_len_q, d_head)
Q = Q.view(batch_size, -1, self.num_q_heads, self.d_head).transpose(1, 2)
# K: (batch, seq_len_k, num_kv_heads * d_head) -> (batch, num_kv_heads, seq_len_k, d_head)
K = K.view(batch_size, -1, self.num_kv_heads, self.d_head).transpose(1, 2)
# V: (batch, seq_len_v, num_kv_heads * d_head) -> (batch, num_kv_heads, seq_len_v, d_head)
V = V.view(batch_size, -1, self.num_kv_heads, self.d_head).transpose(1, 2)
# 3. GQA 的核心:重复 K 和 V 来匹配 Q 的头数
# K: (batch, num_kv_heads, seq_len_k, d_head) -> (batch, num_q_heads, seq_len_k, d_head)
# V: (batch, num_kv_heads, seq_len_v, d_head) -> (batch, num_q_heads, seq_len_v, d_head)
if self.num_groups > 1:
K = K.repeat_interleave(self.num_groups, dim=1)
V = V.repeat_interleave(self.num_groups, dim=1)
# 4. 计算缩放点积注意力 (现在 Q, K, V 的头数维度都匹配了)
context, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# 5. 拼接所有头
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 6. 最终线性投射
output = self.W_o(context)
return output, attn_weights
GQA 的 scaled_dot_product_attention 和 MHA 是完全一样的。
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""
这个核心函数和 MHA 完全一样
"""
attn_scores = Q @ K.mT
scaled_attn_scores = attn_scores / math.sqrt(self.d_head)
if mask is not None:
scaled_attn_scores = scaled_attn_scores.masked_fill(mask == 0, -1e9)
attn_weights = torch.softmax(scaled_attn_scores, dim=-1)
output = attn_weights @ V
return output, attn_weights
MLA
https://kexue.fm/archives/10091
注意力架构演化的核心动机始终是:找到一种注意力架构,在**==推理开销(更小的 KV 缓存)==与==模型性能(注意力的效果)==** 之间的权衡中寻求最优解。
起点与瓶颈 (MHA):MHA 效果卓越,是性能的黄金标准,但其与头数成正比的 KV 缓存成为了处理长文本时不可逾越的效率瓶颈。
激进的探索 (MQA):为了突破这一瓶颈,MQA 采取了激进的策略——所有查询头(Query Head) 共享==同一套 KV 矩阵==。这实现了 KV 缓存的最大化压缩,但也造成模型表达能力明显下降。
务实的平衡 (GQA):GQA 则是在 MHA 和 MQA 之间取得平衡 “sweet point”。它通过将查询头分组,并通过==组内共享 KV 矩阵==,实现了一种务实且高效的折中,在 MLA 出现之前被业界广泛采纳。
全新的野心 (MLA):然而,GQA 本质上仍是一种“权衡”。DeepSeek 更具野心:我们能否超过这种权衡?MLA 在推理时拥有与==MQA==相媲美的极小 KV 缓存,同时在模型性能上恢复到接近 ==MHA== 的理想水平,从而实现“鱼与熊掌兼得”。

欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于