linear_attention 是把 Mamba2 的 gating 机制 和 DeltaNet 的 delta rule 结合起来形成的新结构
先简单分析一下 线性注意力是什么
Overview
从 Transformer 视角
标准 causal self-attention 的一层,大致是:
$$ Q = XW_Q,\quad K = XW_K,\quad V = XW_V $$$$ A = \text{softmax}(QK^\top + \text{mask}),\quad O = AV $$对第 $t$ 个 token 来说,本质上是在做:
$$ o_t = \sum_{i \le t} \alpha_{t,i} v_i $$O(n, t) 中的每一行是把 attn_map 中的对应行作为权重,对 value 矩阵的行进行线性组合(加权)
也就是:当前 token 用 query 去和历史所有 key 打分,再把历史 value 加权求和。这给了 Transformer 很强的“按内容检索历史”的能力,但代价是训练时通常要处理完整的 token-token 交互矩阵。Mamba2、DeltaNet、Gated DeltaNet 这一类工作,核心目标就是:尽量保留这种“从历史取信息”的能力,但不要每次都真的显式看全部历史 token。
如果用一句工程化的话说:
- Transformer:历史以 KV cache / attention matrix 的形式存在
- Gated DeltaNet:历史以 compressed recurrent state 的形式存在
线性 Attention 视角
在线性 attention 里,你会把输出写成某种“特征映射后”的形式,使得历史可以压缩进一个状态矩阵 $S_t$。一个最经典的形态是:
$$ S_t = S_{t-1} + v_t k_t^\top $$$$ o_t = S_t q_t $$这里不要太纠结公式细节,重点是这个结构说明了一件事:
不需要保存所有历史 token; 只需要维护一个不断更新的状态 $S_t$,当前输出就能由 $S_t$ 和当前 query 算出来。
所以它看起来更像一个 矩阵状态的 RNN。作者在 DeltaNet 相关解释里也明确强调:linear attention 本质上就是带矩阵状态的 linear RNN。
你可以把它和 Transformer 对比成:
- Transformer
history = [(k1,v1), (k2,v2), ..., (kt,vt)]
output_t = attend(q_t, history)
- Linear-attention style
state_t = update(state_{t-1}, k_t, v_t)
output_t = read(state_t, q_t)
DeltaNet 和 Gated DeltaNet 都是在这个“state update / state read”框架里工作
DeltaNet 在改什么
如果最简单的 linear attention 是:
$$ S_t = S_{t-1} + v_tk_t^\top $$那它有一个问题:更新太“笨”。
因为这个更新更像是“把新记忆一直往状态里累加”。这样做效率高,但在复杂检索、覆盖旧记忆、冲突信息替换这些场景里,能力往往不够强。
Gated DeltaNet 论文对 DeltaNet 的定位就是:它用 delta rule 来做更精细的记忆修改,从而提升 retrieval 和长上下文能力
delta rule 的核心想法不是“直接把 $v_t k_t^\top$ 加进去”,而是:
- 先看看当前 state 对这个 key 已经会“读出”什么
- 再计算“我真正想写进去的 value”和“当前 state 已经能给出的 value”之间的差
- 只把这个差值写进去
也就是类似:
$$ \text{pred}_t = S_{t-1} k_t $$$$ \Delta_t = v_t - \text{pred}_t $$$$ S_t = S_{t-1} + \Delta_t k_t^\top $$如果 Transformer 的直觉是:
“query 去所有历史里检索最相关内容”
那 DeltaNet 更像:
“历史不是按 token 存,而是被编码进一个大 memory matrix; 当前 token 通过 query 从这个 matrix 里读; 新 token 通过 delta-rule 对这个 matrix 做纠偏式更新。”
所以 DeltaNet 的“像 attention”的地方在于: 它仍然有 key / value / query,仍然有 content-based read/write。 但它不是显式对每个历史 token 打 attention score,而是对一个 压缩状态 做读写。
Mamba2 在改什么
最好把 Mamba2 理解成:
一种把序列处理写成“状态递推”的模型层,重点不是显式 attention,而是让状态在时间上流动,并通过可学习机制控制信息保留与遗忘。
Mamba2 这条线很强调 gating / selective memory control。 也就是每来一个 token,不是无脑更新状态,而是会学到:
- 旧状态保留多少
- 当前 token 写入多少
- 哪些通道更该忘
- 哪些通道更该保留
所以如果你硬要用最粗略的话描述:
- 普通 RNN
state_t = f(state_{t-1}, x_t)
- Mamba2 风格
gate_t = g(x_t)
state_t = gate_t * old_part + (1 - gate_t) * new_part
output_t = read(state_t)
Gated DeltaNet
现在你已经有两块积木了:
- DeltaNet:擅长“精准改写记忆”
- Mamba2:擅长“门控地保留/擦除记忆”
那 Gated DeltaNet 的想法就很自然:
把 delta-rule 的“纠偏式写入” 和 Mamba2 风格的“门控遗忘/控制” 合在同一个 recurrent memory update 里。
论文原话概括就是:
- gating enables rapid memory erasure
- delta rule facilitates targeted updates
用非严格但有帮助的写法:
DeltaNet 风格
$$ S_t = S_{t-1} + (v_t - S_{t-1} k_t) k_t^\top $$Gated DeltaNet 风格
$$ \tilde{S}_{t-1} = g_t \odot S_{t-1} $$$$ \text{pred}_t = \tilde{S}_{t-1} k_t $$$$ \Delta_t = v_t - \text{pred}_t $$$$ S_t = \tilde{S}_{t-1} + \eta_t \cdot \Delta_t k_t^\top $$这里你只需要读懂语义:
$g_t$:先对旧记忆做门控,决定保留多少$\text{pred}_t$:看看当前记忆对这个 key 的“已有回答”$\Delta_t$:和目标 value 的差$\eta_t$:控制这次写入有多强
也就是说,Gated DeltaNet 不是“先忘后重写”那么粗糙,而是:
- 先门控旧状态
- 再计算当前记忆对这条 key 的预测
- 最后按差值做低秩修正写入
与 Transformer Q/K/V 重新对齐
最好把 Gated DeltaNet 看成 “QKV still exists, but attention map disappears”
- Transformer Layer
x_t
├─> q_t
├─> k_t
└─> v_t
scores_t = q_t @ K_past^T
weights_t = softmax(scores_t)
o_t = weights_t @ V_past
- Gated DeltaNet Layer
x_t
├─> q_t
├─> k_t
├─> v_t
└─> gate/update params
state_{t-1} --(gate)--> gated_state
pred_t = gated_state @ k_t
delta_t = v_t - pred_t
state_t = gated_state + write(delta_t, k_t)
o_t = read(state_t, q_t)
注意这里非常关键的一点:
- 在 Transformer 里,read 是通过
softmax(q_t K^T) V - 在 Gated DeltaNet 里,read 是通过
state_t @ q_t或同类线性读出
所以两者都在“query-based retrieval”,但检索对象不同:
- Transformer 检索的是 token list
- Gated DeltaNet 检索的是 compressed matrix state
这就是它与 Transformer 的本质差异
Workflow
先给出一个单头版本:
import torch
def batch_matvec(S, x):
# S: [B, Dv, Dk]
# x: [B, Dk]
return torch.einsum('bvk,bk->bv', S, x)
def batch_outer(a, b):
# a: [B, Dv]
# b: [B, Dk]
return torch.einsum('bv,bk->bvk', a, b)
def gate_memory(S, gate):
# simplest scalar gate version
# S: [B, Dv, Dk]
# gate: [B, 1]
return S * gate[:, None, None]
def gated_deltanet_forward(x, S, params):
# x: [B, T, D]
# S: [B, Dv, Dk]
h = rmsnorm(x) # [B, T, D]
q = h @ params.Wq # [B, T, Dk]
k = h @ params.Wk # [B, T, Dk]
v = h @ params.Wv # [B, T, Dv]
gate = torch.sigmoid(h @ params.Wg_state) # [B, T, 1]
eta = torch.sigmoid(h @ params.Weta) # [B, T, 1]
outputs = []
S_cur = S
for t in range(x.shape[1]):
q_t = q[:, t, :] # [B, Dk]
k_t = k[:, t, :] # [B, Dk]
v_t = v[:, t, :] # [B, Dv]
g_t = gate[:, t, :] # [B, 1], 衰减系数, 0~1 之间
e_t = eta[:, t, :] # [B, 1]
# 遗忘:旧记忆乘以衰减系数
S_gated = gate_memory(S_cur, g_t) # [B, Dv, Dk]
# 检索: 用 k_t 从记忆本里查一下现在存的值是什么
pred = batch_matvec(S_gated, k_t) # [B, Dv]
# Delta(纠错):看看记忆里的值和真实 v_t 差多少
delta = v_t - pred # [B, Dv]
write = batch_outer(delta, k_t) # [B, Dv, Dk]
# e_t 是写入门控
S_new = S_gated + write * e_t[:, None, None] # [B, Dv, Dk]
o_t = batch_matvec(S_new, q_t) # [B, Dv]
outputs.append(o_t)
S_cur = S_new
O = torch.stack(outputs, dim=1) # [B, T, Dv]
y = x + (O @ params.Wo) # [B, T, D]
return y, S_cur