可训练动态掩码稀疏注意力

(arxiv 2025)

Trainable Dynamic Mask Sparse Attention

Smalldoge出品,该组织专注于小型语言模型,专注于效率易用性

现有方法

现有方法要么是利用softmax注意力的稀疏性。

如《From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification》提出了Sparsemax,即$Sparsemax(x)=relu(x-\lambda(x))$。

其中$\lambda(x)$是使得p的各分量之和为1的常数。

其实softmax能写为$softmax(x)=exp(x-\lambda(x))$,而sparsemax正是利用了$exp(x)\approx relu(1+x)$。

或者要么是利用长内容的稀疏性来选择性计算,比如滑动注意力和之前介绍过的NSA。

另一方面,也有从KV入手,比如之前attention sinks介绍过多那一些。

为了解决之前方法的限制,必须利用自注意力稀疏性进行必要计算和利用长上下文稀疏性进行选择性计算。

算法

内容感知动态稀疏掩码
$$
\delta=exp(\tau(v(\Delta)*A))
$$
其中$\Delta$是可学习的采样权重矩阵,A是门控参数,$\tau$是非负函数,这里是有点是softplus,类似VIB中的处理。

再应用 Top-k 选择和标准的因果掩码$m_t^c$:
$$
m_t=f(top_w(\delta+m_t^c))
$$

位置感知的稀疏注意力权重

然后就是类似标准的因果掩码:
$$
o_t=softmax(\frac{q_tk^T}{\sqrt{d_h}}+m_t)v
$$
由于被掩码位置的注意力权重注定为 0,因此在计算 QKT 点积时,可以完全跳过这些位置的计算 。这使得计算复杂度从标准注意力的$O(n^2d_h)$ 降低到 $O(nwd_h)$,其中 n 是序列长度,w 是保留的窗口大小,$d_h$是头维度 。

这部分是通过硬件和类似flash attention的分块操作处理的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def dynamic_mask_attention(h_t, position_embeddings, causal_m, past_key_value,
W_Q, W_K, W_V, W_dt, A, W_O,
num_heads, scaling, keep_window_size):
input_shape = h_t.shape[:-1]
# [b, q_len]
hidden_shape = (*input_shape,-1, h_t.shape[-1] // num_heads)
# linear projections
q_t = W_Q(h_t).view(hidden_shape).transpose(1, 2)
k_t = W_K(h_t).view(hidden_shape).transpose(1, 2)
v_t = W_V(h_t).view(hidden_shape).transpose(1, 2)
o_t = torch.zeros_like(q_t)
# [b, n_h, q_len, d_h]
# [b, n_h, q_len, d_h]
# [b, n_h, q_len, d_h]
# [b, n_h, q_len, d_h]
# apply rotary position embeddings
q_t, k_t = apply_rotary_pos_emb(q_t, k_t, *position_embeddings)
# concatenate past key and value states
k, v = past_key_value.update(k_t, v_t)
# [b, n_h, k_len, d_h]
# calculate dynamic mask
dt = W_dt(v.transpose(1, 2).reshape(v.shape[0], v.shape[-2],-1)) # [b, k_len, n_h]
dt = torch.exp(A * F.softplus(dt)).transpose(-1,-2)
# [b, n_h, k_len]
m_t = dt[:, :, None, :].expand(-1,-1, h_t.shape[1],-1)
# [b, n_h, q_len, k_len]
active_m = torch.zeros_like(m_t)
m_t = m_t.masked_fill(causal_m != 0,-float('inf'))
topk_indices = torch.topk(m_t, keep_window_size, dim=-1, sorted=False).indices
active_m = active_m.scatter(-1, topk_indices, 1.0)
m_t = m_t.masked_fill(active_m == 0.0,-float('inf'))
# calculate sparse attention weight
for b_idx in range(hidden_shape[0]):
for h_idx in range(num_heads):
for q_idx in range(hidden_shape[1]):
q_elem = q_t[b_idx, h_idx, q_idx, :]
indices = topk_indices[b_idx, h_idx, q_idx]
k_vecs = k[b_idx, h_idx, indices, :]
v_vecs = v[b_idx, h_idx, indices, :]
a_elem = torch.sum(q_elem.unsqueeze(0) * k_vecs, dim=-1)
a_elem = a_elem * scaling + m_t[b_idx, h_idx, q_idx, indices]
a_elem = F.softmax(a_elem, dim=-1)
o_elem = torch.sum(a_elem.unsqueeze(1) * v_vecs, dim=0)
o_t[b_idx, h_idx, q_idx, :] = o_elem
o_t = o_t.transpose(1, 2).contiguous()
h_t = W_O(o_t.view(*input_shape,-1))
return h_t

可训练动态掩码稀疏注意力
https://lijianxiong.space/2025/20250805/
作者
LJX
发布于
2025年8月5日
许可协议