MoBA vs NSA

Kimi公开了他们处理长文的秘密了。团队提出了MoBA (Mixture of Block Attention) ,解决了传统注意力机制在处理长文本时的效率问题。

DeepSeek 发布了一篇新论文,提出了一种改进版的注意力机制 NSA(Native Sparse Attention),加上还有创始人兼 CEO 梁文锋亲自参与。

由于MoBA比NSA更简单,于是我们循序渐进先介绍NSA。

这两个都主要对KV进行优化。

NSA

其中定义了三种注意力:压缩(cmp)、选择(sle)、滑窗(win),并最后使用门控来简单汇聚,$g^{cmp}o^{cmp}+g^{sle}o^{sle}+g^{win}o^{win}$。

压缩注意力

压缩注意力的本质是将一段序列的KV压成一个KV。

选择注意力

论文选择将这部分和压缩注意力结合起来。

1
2
3
4
5
6
7
p_slc = p_cmp.sum(dim = 1) # 在head维度上进行合并
print(p_cmp.shape) # torch.Size([1, 4, 32, 4])
print(p_slc.shape) # torch.Size([1, 32, 4])
select_top_k = 2
_, idx = torch.topk(p_slc, dim = 2, k = select_top_k)
print(idx[0,0,:]) # [3,0] 即 q0注意到第3片段和第0片段
idx.shape # [1, 32, 2] : batch_size, q_len, top_k

滑窗注意力

这部分是用来捕捉临近的kv片段。其实很简单。
$$
\tilde{K} _ {t} ^{w i n}={\bf k} _ {t-w : t}, \tilde{V} _ {t} ^{w i n}={\bf v} _ {t-w : t}
$$
值得注意的是,这部分可能会跨block。

汇聚

前面我们已经介绍过$g^{cmp}o^{cmp}+g^{sle}o^{sle}+g^{win}o^{win}$

它实际上使用了 MLP和 sigmoid,即

1
2
3
4
W_gated = torch.randn(dim, 3) # mlp, dim->3: cmp, slc, win
gate = X @ W_gated
gate = F.sigmoid(gate) # sigmoid activation
print(gate.shape) # 1, 32, 3 , bs, q_len, gated

MoBA

也正如苏神在知乎上所说,“如果读者对比过NSA和MoBA,估计都有种MoBA是NSA的简化版的感觉:NSA用了MLP来压缩block,MoBA直接用Mean Pooling;NSA用另一块压缩的Attention来学block select,MoBA直接去掉了这部分。不得不说,NSA的设计是更符合一般人的想法,如果由我自己独立来设计MoBA,估计最终形式会更像NSA,因为MoBA这种极致简化的做法则更需要一点勇气(以及长时间的尝试)。”

相比NSA,MoBA更简单,block select相当于更精细的内容,但MoBA把这部分去掉了。Mean Pooling代替MLP减少了许多参数量,这让我想到了global average pooling 中指出:”One advantage of global average pooling over the fully connected layers is that it is more native to the convolution structure. Another advantage is that there is no parameter to optimize in the global average pooling, thus overfitting is avoided at this layer.“

同样地,MoBA也把KV分为若干个block。

操作也很简单和明显,我们可以看出它具体做了什么,图中的内容无需再赘述。

此外,保持自回归语言模型的因果关系很重要。

MoBA设置了不能关注未来块,另外,整个块的平均池化可能会无意包含来自未来标记的信息。未来解决这些问题,它强制要求每个token必须路由到当前块,并应用因果掩码。

MoBA也有更多的灵活性,它的参数与full attention参数相比数量不变,不增不减。这一特性启发我们进行全注意力与 MoBA 之间的平滑过渡。具体来说,在初始化阶段,每个注意力层可以选择全注意力或 MoBA,如果需要,这个选择可以在训练过程中动态更改。

当然它并不是可免训练的、即插即用的,作者指出:“MoBA没有参数是不是拿来就可以在现有模型上用? MoBA 不是一个免训练 sparse attention,虽然是无额外参数的,但是依然需要对现有模型进行Continue Training。训练中关注Trailing token loss下降情况,或者直接关注 longctx 相关 bmk 涨点情况即可“

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def moba_attn_varlen_naive(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
moba_chunk_size: int,
moba_topk: int,
) -> torch.Tensor:
"""Implement the moba brute-force setting for reference

Args:
q (torch.Tensor): [seqlen, head, head_dim]
k (torch.Tensor): [seqlen, head, head_dim]
v (torch.Tensor): [seqlen, head, head_dim]
cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn
max_seqlen (int): the max sequence length of the batch, same definition in flash attn

Returns:
attn_output (torch.Tensor): [seqlen, head, head_dim]
"""

# qkv shape = [ S, H, D ]
batch = cu_seqlens.numel() - 1
softmax_scale = q.shape[-1] ** (-0.5)

o = torch.zeros_like(q)
for batch_idx in range(batch):
batch_start = cu_seqlens[batch_idx].item()
batch_end = cu_seqlens[batch_idx + 1].item()
# get qkv of this batch
q_ = q[batch_start:batch_end]
k_ = k[batch_start:batch_end]
v_ = v[batch_start:batch_end]
o_ = o[batch_start:batch_end]
# calc key gate weight
key_gate_weight = []
batch_size = batch_end - batch_start
num_block = math.ceil(batch_size / moba_chunk_size)
for block_idx in range(0, num_block):
block_start = block_idx * moba_chunk_size
block_end = min(batch_size, block_start + moba_chunk_size)
key_gate_weight.append(k_[block_start:block_end].mean(dim=0, keepdim=True))
key_gate_weight = torch.cat(key_gate_weight, dim=0) # [ N, H, D ]
# calc & mask gate
# use fp32 to avoid precision issue in bf16
q_ = q_.type(torch.float32)
key_gate_weight = key_gate_weight.type(torch.float32)
gate = torch.einsum("shd,nhd->hsn", q_, key_gate_weight) # [ H, S, N ]
key_gate_weight = key_gate_weight.type_as(k)
q_ = q_.type_as(k)
for i in range(num_block):
# select the future Qs that can attend to KV chunk i
gate[:, : (i + 1) * moba_chunk_size, i] = float("-inf")
gate[:, i * moba_chunk_size : (i + 1) * moba_chunk_size, i] = float("inf")
# gate_top_k_idx = gate_top_k_val = [ H S K ]
gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=min(moba_topk, num_block), dim=-1, largest=True, sorted=False
)
gate_top_k_val, _ = gate_top_k_val.min(dim=-1) # [ H, S ]
need_attend = gate >= gate_top_k_val.unsqueeze(-1)
# add gate_idx_mask in case of there is cornercases of same topk val been selected
gate_idx_mask = torch.zeros(
need_attend.shape, dtype=torch.bool, device=q.device
)
gate_idx_mask = gate_idx_mask.scatter_(dim=-1, index=gate_top_k_idx, value=True)
need_attend = torch.logical_and(need_attend, gate_idx_mask)
gate[need_attend] = 0
gate[~need_attend] = -float("inf")
gate = gate.repeat_interleave(moba_chunk_size, dim=-1)[
:, :, :batch_size
] # [ H, S, S ]
gate.masked_fill_(
torch.ones_like(gate, dtype=torch.bool).tril().logical_not(), -float("inf")
)

# calc qk = qk^t
q_ = q_.type(torch.float32)
k_ = k_.type(torch.float32)
v_ = v_.type(torch.float32)
qk = torch.einsum("xhd,yhd->hxy", q_, k_)
# mask
qk += gate
qk *= softmax_scale
# calc o
p = qk.softmax(dim=-1)
o_ += torch.einsum("hxy,yhd->xhd", p, v_)
o = o.type_as(q)

return o

作者也介绍了一些心路历程,也值得一看。

MoBA VS NSA

它们都有和flash attention比较,也都达到了100%的大海捞针测试。

其中有一个有趣的是

MoBA的损失曲线一开始不如full attention,但后续逐渐重合。

而NSA则全面优于full attention。

另外,如果全用MoBA可能会有问题,所以后续需要加上几层(苏神说一层就足够)full attention。而NSA则没有这一问题。

参考资料:

  1. 【手撕NSA】DeepSeek新作-原生稀疏注意力-超长文(附代码)

MoBA vs NSA
https://lijianxiong.work/2025/20250222/
作者
LJX
发布于
2025年2月22日
许可协议