Native Sparse Attention

参考资料

2502.11089v2.pdf

核心观点

Sparse Attention

softmax-based 的 attention 机制决定了它固有的稀疏性,即:

很多 QK 对算出来的 attention score 接近于 0, 计算它们并没有意义。

image

如图,紫色部分表示了哪些 QK 对是极为相关的,并反应了在 QK 和 AV 的计算过程中潜在的稀疏计算特性。

如何利用稀疏性来优化算法效率?

  1. KV-cache 淘汰机制

    SnapKV​认为提示词末尾的 token 可以揭示整个 prompt 的那些 token 是重要的。

    1. 基于 Observation Window 计算 Prompt 各个位置的注意力权重之和行
    2. 使用 1 维 Pooling 进行权重聚类,选出 Top-k 的位置索引行
    3. 利用 gather 操作提取选中位置的 Key 和 Value 行
    4. 将压缩的 KV 与 Observation Window 对应的 KV 拼接,得到最终结果行

    举个简单例子,假设 Prompt 序列长度为 1000,Observation Window 大小为 16,我们希望将 KV 缓存压缩至 256。SnapKV 首先基于最后 16 个 Token 的注意力分布,通过 Voting 算法选出最重要的 240 个位置。然后将这 240 个位置对应的 Key 和 Value 与 Observation Window 拼接,形成大小为 256 的新 KV 缓存。这样,后续的生成过程就只需在显著缩减的 KV 缓存上进行注意力计算,从而大幅提升了效率。

    def snap_kv ( query_states , key_states , value_states , window_size , max_capacity_prompt ,
    kernel_size ):
    	bsz , num_heads , q_len , head_dim = query_states . shape
    	# Ensure it is the prompt phase .
    	assert key_states . shape [ -2] == query_states . shape [ -2]
    	if q_len < max_capacity_prompt :
    		return key_states , value_states
    	else :
    		# Compute attention weights of observing window ’s queries and prefix context ’s Keys .
    		attn_weights = compute_attn ( query_states [... , - window_size :, :] , key_states ,
    	attention_mask )
    		# Sum the weight along the query dimension .
    		vote = attn_weights [... , - window_size :, :- window_size ]. sum ( dim = -2)
    		# Apply 1D pooling for clustering .
    		# Select top -k indices based on the pooled weights to identify important positions .
    		indices = pool_vote . topk ( max_capacity_prompt - window_size , dim = -1) . indices
    		# Expand the indices to match the head dimension for gathering .
    		indices = indices . unsqueeze ( -1) . expand ( -1 , -1, -1, head_dim )
    		# Gather the compressed past key and value states based on the selected indices .
    		k_past_compress = key_states [... , : - window_size , :]. gather ( dim =2 , index = indices )
    		v_past_compress = value_states [... , :- window_size , :]. gather ( dim =2 , index = indices )
    		k_obs = key_states [... , - window_size :, :]
    		v_obs = value_states [... , - window_size :, :]
    		key_states = torch . cat ([ k_past_compress , k_obs ], dim =2)
    		value_states = torch . cat ([ v_past_compress , v_obs ], dim =2)
    		return key_states , value_states
    
  2. blockwise KV-cache 选择机制

    SeerAttention

    image

    • 训练阶段:使用原始的 attention 模型计算 attention map,通过 2D MaxPooling 得到一个 Groud Truth。这个 Groud Truth 反映了 blockwise token 之间的相关性。类似于 Distillation,额外训练了一个以 blockwise token 作为输入的模型,训练这个模型的输出逼近 Ground Truth。
    • 推理阶段:使用蒸馏得到的模型快速计算出 Ground Truth ,使用 TopK 或者 Threshold 策略进行筛选,得到一个 Mask 矩阵。那么,从理论上来说,原始 attention 模型只需要计算 Mask 中被选择的部分。

    训练了一个辅助模型,用来低成本的计算 blockwise token 之间的相关性。从而快速选择出主模型(不需要更改结构)哪些 block 是需要计算的。

  3. 基于采样、聚类或哈希的选择方法

Overview

image

核心算法流程如下:

NSA 采用了三个分支来计算 Attenttion ,每个分支采用独立且不同的 \tilde{K}_{t}\tilde{V}_{t} 作为输入。最后通过一个可学习的门控来选择最终的输出。

\begin{array}{c}\tilde{K}_{t}=f_{K}\left(\mathbf{q}_{t}, \mathbf{k}_{: t}, \mathbf{v}_{: t}\right), \quad \tilde{V}_{t}=f_{V}\left(\mathbf{q}_{t}, \mathbf{k}_{: t}, \mathbf{v}_{: t}\right) \\\mathbf{o}_{t}^{*}=\operatorname{Attn}\left(\mathbf{q}_{t}, \tilde{K}_{t}, \tilde{V}_{t}\right)\end{array}

\mathbf{o}_{t}^{*}=\sum_{c \in C} g_{t}^{c} \cdot \operatorname{Attn}\left(\mathbf{q}_{t}, \tilde{K}_{t}^{c}, \tilde{V}_{t}^{c}\right)

C=\{\mathrm{cmp}, \mathrm{slc}, \mathrm{win}\}

Token Compression

通过将一系列连续的 Key 或 Value 聚合为块级(block-level)表征,我们可以得到压缩后的键(key)与值(value),它们能够概括整个块的信息。

\tilde{K}_{t}^{\mathrm{cmp}}=f_{K}^{\mathrm{cmp}}\left(k_{: t}\right)=\left\{\varphi\left(k_{i d+1: i d+l}\right) \left\lvert\, 0 \leq i \leq\left\lfloor\frac{t-l}{d}\right\rfloor\right.\right\}

其中,

  • l:块长度(block length),表示每次聚合多少个连续的 key;
  • d:相邻块之间的滑动步长(sliding stride);
  • \varphi(\cdot):一个可学习的 MLP 网络,带有块内位置编码(intra-block position encoding),用于将该块内的多个 key 映射到一个压缩后的 key。

\tilde{K}_{t}^{\mathrm{cmp}} \in \mathbb{R}^{d_{k} \times\left\lfloor\frac{t-l}{d}\right\rfloor}

image

需要注意的是,在多头注意力中,Compression Attention 每个头需要单独进行压缩。因此,输入与输出的维度应是:

q [batch, seqlen_q, n_head_q * head_dim] --> [batch, nhead_q, seqlen_q, head_dim]
k,v [batch, seqlen_kv, n_head_kv * head_dim] --> [btach, nhead_kv, seqlen_kv, head_dim] --> [btach, nhead_kv, blocknum, head_dim]
P [batch, nhead_q, seqlen_q, blocknum]
O [batch, nhead_q, seqlen_q, head_dim] --> [batch, seqlen_q, n_head_q * head_dim]

# 多头压缩注意力
Q_mha = Q.view(1, t, heads, head_dim).transpose(1,2)
K_cmp_mha = K_cmp.view(1, block_nums, heads, head_dim).transpose(1,2)
V_cmp_mha = V_cmp.view(1, block_nums, heads, head_dim).transpose(1,2)
score_cmp = Q_mha @ K_cmp_mha.transpose(2,3) # bs, head, q_len, k_cmp_len
print(score_cmp.shape) # torch.Size([1, 4, 32, 4])

p_cmp = F.softmax(score_cmp, dim = -1) # torch.Size([1, 4, 32, 4)
o_cmp = p_cmp @ V_cmp_mha
print(o_cmp.shape) # torch.Size([1, 4, 32, 4]) 

o_cmp = o_cmp.transpose(2, 1).reshape(batch_size, t, dim)
print(o_cmp.shape) # torch.Size([1, 32, 16])

Token Selection

image

仅使用压缩后的 Key 和 Value 可能会丢失重要的细粒度信息,因此我们需要有选择地保留部分关键 token。下面介绍一种高效的 token 选择机制,它能够以极低的计算开销识别并保留最重要的 token。

  1. Blockwise Selection

选择策略以空间上连续的块为单位处理 Key 和 Value 序列,主要受到以下两个因素驱动:

  • 硬件效率考虑,Kernel 实现是以 tile 为单位来做,也有利于 tensor core 的计算和连续访存。
  • 注意力得分的固有分布规律,注意力得分通常具有空间连续性
  1. 基于 importance score 来选择

直接计算块级(block-level)重要性得分可能会带来显著的计算开销。幸好这里可以复用 压缩分支 计算注意力时得到的注意力分数

\mathbf{p}_t^{\mathrm{cmp}} = \mathrm{Softmax}\left( \mathbf{q}_t^{\mathrm{T}} \tilde{\mathbf{K}}_t^{\mathrm{cmp}} \right)

如果,Token Selection 使用的 block 划分策略与 Token Compression 不一致时,需要做一些额外处理:

\mathbf{p}_t^{\mathrm{slc}}[j] = \sum_{m=0}^{\frac{l'}{d}-1} \sum_{n=0}^{\frac{l'}{d}-1} \mathbf{p}_t^{\mathrm{cmp}}\!\left[\frac{l'}{d}j - m - n\right]

对于多头自注意力,在实现的 kernel 时,我们希望==每个 query head 能够选择相同的 KV block==,因此需要对每个头的注意力进行聚合进行统一选择:

\mathbf{p}_t^{\mathrm{slc}'} = \sum_{h=1}^{H} \mathbf{p}_t^{\mathrm{slc},(h)}

image

\mathcal{I}_t = \{ i \mid \mathrm{rank}(\mathbf{p}_t^{\mathrm{slc}'}[i]) \leq n \} \\ \tilde{\mathbf{K}}_t^{\mathrm{slc}} = \mathrm{Cat}\!\left( \{ \mathbf{k}_{i l' + 1 : (i + 1) l'} \mid i \in \mathcal{I}_t \} \right)

P_cmp [batch, nhead_q, seqlen_q, blocknum]  --> P_slc [batch, seqlen_q, clocknum]
q [batch, seqlen_q, n_head_q * head_dim] --> [batch, nhead_q, seqlen_q, head_dim]
K_slc [batch, seqlen_q, block_size * select_top_k, n_head_kv * dim]
V_slc [batch, seqlen_q, block_size * select_top_k, n_head_kv * dim]
O [batch, nhead_q, seqlen_q, head_dim] --> [batch, seqlen_q, n_head_q * head_dim]

p_slc = p_cmp.sum(dim = 1) # 在head维度上进行合并
select_top_k = 2
_, idx = torch.topk(p_slc, dim = 2, k = select_top_k)
print(idx[0,0,:]) # [2, 8] 即 q0注意到第2片段和第8片段
idx.shape # [1, 32, 2] : batch_size, q_len, top_k

idx_slc_start = idx * d
idx_slc_end = idx * d + l
K_slc = torch.randn(batch_size, t, d * select_top_k, dim)
V_slc = torch.randn(batch_size, t, d * select_top_k, dim)
for i in range(batch_size):
    for j in range(t):
        for k in range(select_top_k):
            K_slc[i, j, k * d : k * d + l, :] = K[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
            V_slc[i, j, k * d : k * d + l, :] = V[i, idx_slc_start[i, j, k ] :  idx_slc_end[i, j, k ] , :]
print(K_slc.shape) # bs, seq_len, select_kv, dim, 1,32,16,16, 不同qt选到不同的select_kv
print(V_slc.shape) # bs, seq_len, select_kv, dim  1,32,16,16, 不同qt选到不同的select_kv

# shared head KV
# IN GQA Group: [1-head KV & N-head Q] ----repeat kv-head---> [N-head KV & N-head Q]

V_slc_mha = V_slc.view(batch_size, t, select_top_k * d, heads, head_dim).transpose(2,3)
V_slc = V_slc_mha.sum(dim = 2, keepdim = True)
print(V_slc.shape) # bs, seq_len, head, select_seq_len, head_dim

K_slc_mha = K_slc.view(batch_size, t, select_top_k * d, heads, head_dim).transpose(2,3)
K_slc = K_slc_mha.sum(dim = 2, keepdim = True)
print(V_slc.shape) # bs, seq_len, head, select_seq_len, head_dim

o_slc = torch.zeros(batch_size, t, dim)
for j in range(t):
    Q_slc_j = Q_mha[:, :, j, :].unsqueeze(dim = 2)
    K_slc_j = K_slc[:, j, :, :, :].repeat(1, heads, 1, 1)
    V_slc_j = V_slc[:, j, :, :, :].repeat(1, heads, 1, 1)
    
    attn_score_j = Q_slc_j @ K_slc_j.transpose(2,3)
    p_slc_j = F.softmax(attn_score_j, dim = -1) 
    # print(p_slc.shape)

    o_slc_j = p_slc_j @ V_slc_j # bs, seq, dim   
    # print(o_slc_j.shape)

    o_slc_j = o_slc_j.transpose(1,2).view(batch_size, 1, dim)
    o_slc[:, j, :] = o_slc_j
print(o_slc.shape)

Sliding Window

窗口注意力是捕捉与当前 q 最近的 kv 片段,即越相近的 KV 就越重要。

在注意力机制中,局部模式通常适应得更快,并且在学习过程中占主导地位,这可能导致模型难以有效地从压缩和选择 token 中学习。为了解决这个问题,作者引入了一个专门的滑动窗口分支,它显式地处理局部上下文,从而使其他分支能够专注于学习各自的特征,而不会被局部模式(shortcut)所干扰。

具体而言,作者维护了最近的 token 序列:
\tilde{K}_{t}^{\text {win }}=k_{t-w: t}, \quad \tilde{V}_{t}^{\text {win }}=v_{t-w: t}
即在一个窗口大小为 w 的范围内,保留最近的 w 个键值对。注意力计算被划分为三个独立的分支:压缩分支、选择分支与滑动窗口分支。每个分支独立计算后,通过一个学习得到的门控机制进行聚合。

# built sliding window attention
def get_window_mask(seq_len, window):
    mask = torch.ones(seq_len, seq_len, dtype = torch.long)
    mask = torch.tril(mask)
    win_mask = -torch.ones(seq_len - window, seq_len - window, dtype = torch.long)
    win_mask =  torch.tril(win_mask)
    mask[window:, :seq_len - window] += win_mask
    return mask
print(get_window_mask(7, 3)) # test
window_mask = get_window_mask(t, 8)

压缩率 : (seqlen / block_size + top_k * block_size + w ) / seqlen_kv

Gated

为防止不同分支之间出现“梯度捷径”并增加稳定性,作者为三个分支提供了独立的键和值。这种架构设计在不显著增加计算开销的情况下,有助于减少局部与长程模式识别之间的梯度干扰,从而实现更稳定的训练。

在获得三类键和值(压缩分支、选择分支、滑动窗口分支):
\left(\tilde{K}_{t}^{\text {cmp }}, \tilde{V}_{t}^{\text {cmp }} ; \tilde{K}_{t}^{\text {slc }}, \tilde{V}_{t}^{\text {slc }} ; \tilde{K}_{t}^{\text {win }}, \tilde{V}_{t}^{\text {win }}\right)
最终注意力输出按照公式 \mathbf{o}_{t}^{*}=\sum_{c \in C} g_{t}^{c} \cdot \operatorname{Attn}\left(\mathbf{q}_{t}, \tilde{K}_{t}^{c}, \tilde{V}_{t}^{c}\right)

W_gated = torch.randn(dim, 3) # mlp, dim->3: cmp, slc, win
gate = X @ W_gated
gate = F.sigmoid(gate) # sigmoid activation
print(gate.shape) # 1, 32, 3 , bs, q_len, gated

o_list = [o_cmp, o_slc, o_win]
o_star = torch.zeros(batch_size, t, dim)
for i in range(3):
    o_star += gate[:, :, i].unsqueeze(2) * o_list[i]
print(o_star.shape)

计算。结合前文介绍的压缩、选择与滑动窗口机制,这三者共同构成了 NSA(Natively Sparse Attention)算法框架的完整结构。

Kernel Design

Token Compression 和 Sliding Window 分支可以直接复用 FlashAttention-2。

但是 Token Selection 的 KV 不连续的内存访问模式,直接使用 FlashAttention-2 性能较低。如果直接采用 FA 策略,将连续的 query tile 加载入 SRAM,会导致内存访问效率低下,因为同一个 query tile 内的不同 query 可能需要访问不同的 KV 块。

为解决这一问题,该文提出了不同的 query 分组策略:
对于查询序列中的每个位置,我们将同一个 GQA 组内的所有查询头(这些头共享相同的稀疏 KV 块)一起加载到 SRAM 中。

  • 以分组为中心的数据加载(Group-Centric Data Loading)
    在每个内部循环(inner loop)中,加载位于位置 t 的一个组(group)中所有头(heads)的查询向量 Q \in \mathbb{R}^{[h, d_k]},以及它们共享的稀疏键/值(key/value)块索引集合 \mathcal{I}_t

  • 共享 KV 读取(Shared KV Fetching)
    在内部循环中,顺序地将由索引集 \mathcal{I}_t 标识的连续键/值块加载到 SRAM 中,即

    K \in \mathbb{R}^{[B_k, d_k]}, \quad V \in \mathbb{R}^{[B_k, d_v]},

    以最小化内存加载开销。这里,B_k 表示满足块大小约束 B_k \lVert l' \rVert 的 kernel block 尺寸。

  • 网格级外循环(Outer Loop on Grid)
    由于内部循环的长度(与选定的块数量 n 成比例)对于不同查询块几乎相同,我们将查询/输出循环分配给 Triton 的网格调度器(grid scheduler) 来统一管理,从而简化并优化 kernel 执行。

image

  • 算法
    429 引用 • 254 回帖 • 24 关注
5 操作
Lucre 在 2025-11-14 15:33:12 更新了该帖
Lucre 在 2025-11-14 13:32:55 更新了该帖
Lucre 在 2025-11-14 11:51:16 更新了该帖
Lucre 在 2025-11-14 11:37:34 更新了该帖 Lucre 在 2025-11-14 10:42:46 更新了该帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...