注意力机制

image

MHA

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (MHA) 的标准实现
    """
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model (int): 模型的总维度。
            num_heads (int): 注意力头的数量。
        """
        super().__init__()
        # 确保 d_model 可以被 num_heads 整除
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads # 每个头的维度

        # 定义 Q, K, V 的线性投射层
        # 这里我们将它们合并到一个大的线性层,也可以分开定义
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # 最终的输出线性层
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """
        前向传播
        Args:
            query (Tensor): shape (batch_size, seq_len_q, d_model)
            key (Tensor):   shape (batch_size, seq_len_k, d_model)
            value (Tensor): shape (batch_size, seq_len_v, d_model)
            mask (Tensor, optional): Defaults to None.
        """
        batch_size = query.size(0)

        # 1. 线性投射
        Q = self.W_q(query) # (batch_size, seq_len_q, d_model)
        K = self.W_k(key)   # (batch_size, seq_len_k, d_model)
        V = self.W_v(value) # (batch_size, seq_len_v, d_model)

        # 2. 拆分成多个头
        # view: 重塑 tensor
        # transpose: 交换维度,将 num_heads 维度提前,方便并行计算
        # (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_head) -> (batch, num_heads, seq_len, d_head)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)

        # 3. 计算缩放点积注意力
        # context 的 shape: (batch, num_heads, seq_len_q, d_head)
        context, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. 拼接所有头
        # transpose 和 contiguous().view() 是上面拆分操作的逆过程
        # (batch, num_heads, seq_len_q, d_head) -> (batch, seq_len_q, num_heads, d_head)
        context = context.transpose(1, 2).contiguous()
        # -> (batch, seq_len_q, d_model)
        context = context.view(batch_size, -1, self.d_model)

        # 5. 最终线性投射
        output = self.W_o(context) # (batch, seq_len_q, d_model)

        return output, attn_weights

恢复 output 的维度时,需要注意一点,transpose() 会返回一个 非连续张量(因为它只是换了 strides,不是真正复制内存)。.view() 要求张量必须是 内存连续的,否则就会报错。所以必须使用 ==.contiguous()== 来把张量转化为内存连续的

这个 scaled_dot_product_attention 的实现如下:

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        计算缩放点积注意力的核心函数
        """
        # 1. 计算 Q 和 K^T 的点积
        attn_scores = torch.matmul(Q, K.transpose(-2, -1))

        # 2. 缩放
        scaled_attn_scores = attn_scores / math.sqrt(self.d_head)

        # 3. (可选) 应用 mask
        if mask is not None:
            # mask 的值为 True 的地方会被设置为一个非常小的负数(-1e9 -> -inf -> softmax -> 0)
            scaled_attn_scores = scaled_attn_scores.masked_fill(mask == 0, -1e9)

        # 4. 计算 softmax 得到注意力权重
        attn_weights = torch.softmax(scaled_attn_scores, dim=-1) # dim=-1 一定要写,没有默认值,会报错

        # 5. 将权重应用于 V
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

这里有一些 pytorch 的 API 需要注意。转置一个矩阵的方法有两种:使用 ==A.transpose(-1, -2)==, or ==A.mT== 这两个操作是等价的。A.mT 是 matrix transpose 的意思,会自动交换最后两个维度。注意,在高维矩阵上不能使用 .T 这个操作,这个操作会把所有维度都转置了,.T 只能用在二维的情况下。

还有一个 API 需要注意的:我们对 attention_sore 进行掩码时使用的 API 是 ==masked_fill(condition, val)==。用法就是给 mask==0 的元素,赋予一个趋近于负无穷的数字,这样经过 softmax 函数以后会输出 0 。写出调用方法: ==masked_fill(m == 0, -1e9)==。

GQA

GQA 的实现和 MHA 基本上没啥区别。核心的区别有两个:1. KV 的 projection 矩阵的输出维度更小。W_k = ==nn.Linear(dim, num_kv_heads * h_dim)==;2. 计算 scaled_dot_product_attention 之前,我们必须要对 KV 矩阵进行重复(使用 ==K.repeat_interleave()== API),将 KV 矩阵恢复为和 Q 一样的 shape 进行计算。

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA) 的实现
    """
    def __init__(self, d_model, num_q_heads, num_kv_heads):
        """
        Args:
            d_model (int): 模型的总维度。
            num_q_heads (int): Query 头的数量。
            num_kv_heads (int): Key/Value 头的数量 (GQA中的 'G')。
        """
        super(GroupedQueryAttention, self).__init__()
        # 确保维度和头的数量是可整除的
        assert d_model % num_q_heads == 0, "d_model must be divisible by num_q_heads"
        assert num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"

        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads # 每个 K/V 组对应的 Q 头的数量
        self.d_head = d_model // num_q_heads

        # 定义 Q, K, V 的线性投射层
        # Q 的投射维度是完整的 d_model
        self.W_q = nn.Linear(d_model, d_model)
        # K 和 V 的投射维度更小,因为它们的头更少
        self.W_k = nn.Linear(d_model, self.num_kv_heads * self.d_head)
        self.W_v = nn.Linear(d_model, self.num_kv_heads * self.d_head)

        # 最终的输出线性层
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1. 线性投射
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. 拆分 Q, K, V 头
        # Q: (batch, seq_len_q, d_model) -> (batch, num_q_heads, seq_len_q, d_head)
        Q = Q.view(batch_size, -1, self.num_q_heads, self.d_head).transpose(1, 2)
        # K: (batch, seq_len_k, num_kv_heads * d_head) -> (batch, num_kv_heads, seq_len_k, d_head)
        K = K.view(batch_size, -1, self.num_kv_heads, self.d_head).transpose(1, 2)
        # V: (batch, seq_len_v, num_kv_heads * d_head) -> (batch, num_kv_heads, seq_len_v, d_head)
        V = V.view(batch_size, -1, self.num_kv_heads, self.d_head).transpose(1, 2)

        # 3. GQA 的核心:重复 K 和 V 来匹配 Q 的头数
        # K: (batch, num_kv_heads, seq_len_k, d_head) -> (batch, num_q_heads, seq_len_k, d_head)
        # V: (batch, num_kv_heads, seq_len_v, d_head) -> (batch, num_q_heads, seq_len_v, d_head)
        if self.num_groups > 1:
            K = K.repeat_interleave(self.num_groups, dim=1)
            V = V.repeat_interleave(self.num_groups, dim=1)

        # 4. 计算缩放点积注意力 (现在 Q, K, V 的头数维度都匹配了)
        context, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 5. 拼接所有头
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 6. 最终线性投射
        output = self.W_o(context)

        return output, attn_weights

GQA 的 scaled_dot_product_attention 和 MHA 是完全一样的。

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        这个核心函数和 MHA 完全一样
        """
        attn_scores = Q @ K.mT
        scaled_attn_scores = attn_scores / math.sqrt(self.d_head)
        if mask is not None:
            scaled_attn_scores = scaled_attn_scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scaled_attn_scores, dim=-1)
        output = attn_weights @ V
        return output, attn_weights

MLA

https://kexue.fm/archives/10091

注意力架构演化的核心动机始终是:找到一种注意力架构,在**==推理开销(更小的 KV 缓存)====模型性能(注意力的效果)==** 之间的权衡中寻求最优解。
起点与瓶颈 (MHA):MHA 效果卓越,是性能的黄金标准,但其与头数成正比的 KV 缓存成为了处理长文本时不可逾越的效率瓶颈。
激进的探索 (MQA):为了突破这一瓶颈,MQA 采取了激进的策略——所有查询头(Query Head) 共享==同一套 KV 矩阵==。这实现了 KV 缓存的最大化压缩,但也造成模型表达能力明显下降。
务实的平衡 (GQA):GQA 则是在 MHA 和 MQA 之间取得平衡 “sweet point”。它通过将查询头分组,并通过==组内共享 KV 矩阵==,实现了一种务实且高效的折中,在 MLA 出现之前被业界广泛采纳。
全新的野心 (MLA):然而,GQA 本质上仍是一种“权衡”。DeepSeek 更具野心:我们能否超过这种权衡?MLA 在推理时拥有与==MQA==相媲美的极小 KV 缓存,同时在模型性能上恢复到接近 ==MHA== 的理想水平,从而实现“鱼与熊掌兼得”。

image

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...