Decoder 的基本结构
Decoder 的核心任务是自回归预测:给定左侧上下文,预测下一个 token。
\[
P(t_0,\ldots,t_{N-1})
=
\prod_{i=0}^{N-1}P(t_i\mid t_{<i}).
\]
GPT 和教学版 NNQS 中的 AmplitudeTransformer 都属于这个思路。区别在于:语言模型预测自然语言 token,NNQS 预测 orbital pair token。
输入是什么
输入 token id 先经过 embedding 与位置编码:
\[
T\in\mathbb{N}^{B\times N}
\quad\longrightarrow\quad
X\in\mathbb{R}^{B\times N\times d_{\rm model}}.
\]
这里:
- \(B\):batch size。
- \(N\):序列长度。
- \(d_{\rm model}\):每个 token 的隐藏维度。
进入 decoder block 的 \(X\) 已经是一组连续向量,不再是离散 token id。
一个 Decoder Block
一个 decoder block 可以看成:
masked self-attention
-> residual + layer norm
-> feed-forward network
-> residual + layer norm
常见现代实现使用 Pre-LN:
\[
Y=X+\mathrm{MHA}(\mathrm{LayerNorm}(X)),
\]
\[
Z=Y+\mathrm{FFN}(\mathrm{LayerNorm}(Y)).
\]
其中 \(X,Y,Z\) 的形状都相同:
\[
X,Y,Z\in\mathbb{R}^{B\times N\times d_{\rm model}}.
\]
这也是 Transformer block 容易堆叠的原因:每一层输入输出形状保持一致。
Masked Self-Attention
Decoder 的 self-attention 与普通 self-attention 的差别在 mask。第 \(i\) 个位置只能读取 \(j\le i\) 的 token:
\[
\tilde S_{ij}
=
\begin{cases}
S_{ij}, & j\le i,\\
-\infty, & j>i.
\end{cases}
\]
其中:
\[
S={QK^{\mathsf T}\over\sqrt{d_h}}.
\]
softmax 后,未来位置的权重为 \(0\),所以第 \(i\) 个位置的输出只依赖左侧上下文。
多头与输出投影
设有 \(h\) 个 head,每个 head 的维度为:
\[
d_h={d_{\rm model}\over h}.
\]
每个 head 独立计算:
\[
H_a=\mathrm{Attention}(Q_a,K_a,V_a),
\qquad
H_a\in\mathbb{R}^{B\times N\times d_h}.
\]
拼接后:
\[
H=\mathrm{Concat}(H_1,\ldots,H_h)
\in\mathbb{R}^{B\times N\times d_{\rm model}}.
\]
再经过输出投影:
\[
\mathrm{MHA}(X)=HW^O,
\qquad
W^O\in\mathbb{R}^{d_{\rm model}\times d_{\rm model}}.
\]
所以 attention 子层输出仍然是:
\[
\mathrm{MHA}(X)\in\mathbb{R}^{B\times N\times d_{\rm model}}.
\]
FFN 的角色
Attention 负责让不同位置交换信息。FFN 对每个位置独立做非线性变换:
\[
\mathrm{FFN}(x)
=
\sigma(xW_1+b_1)W_2+b_2.
\]
其中:
\[
W_1\in\mathbb{R}^{d_{\rm model}\times d_{\rm ff}},
\qquad
W_2\in\mathbb{R}^{d_{\rm ff}\times d_{\rm model}}.
\]
常见选择是:
\[
d_{\rm ff}\approx 4d_{\rm model}.
\]
FFN 不混合不同 token 位置;它只把每个位置已经聚合到的上下文信息再加工一次。
输出 Logits
堆叠 \(L\) 个 decoder block 后得到:
\[
Z_L\in\mathbb{R}^{B\times N\times d_{\rm model}}.
\]
输出层把每个位置映射到词表:
\[
\mathrm{logits}=Z_LW_{\rm vocab}+b_{\rm vocab},
\]
其中:
\[
W_{\rm vocab}\in\mathbb{R}^{d_{\rm model}\times V}.
\]
于是:
\[
\mathrm{logits}\in\mathbb{R}^{B\times N\times V}.
\]
第 \(i\) 个位置的 logits 用来预测下一个 token。训练时常写成:
输入: [BOS, t0, t1, ..., t_{N-2}]
目标: [t0, t1, t2, ..., t_{N-1}]
训练:并行但不能看未来
训练时可以把整段序列一次送入模型。虽然每个位置通过 causal mask 看不到未来,所有位置的矩阵乘法仍然可以并行计算。
这就是 decoder-only Transformer 的一个妙处:
所以训练时通常不需要 KV cache。
推理:逐 Token 生成
推理或采样时,模型不能一次知道完整目标序列,只能一步一步生成:
[BOS]
-> t0
[BOS, t0]
-> t1
[BOS, t0, t1]
-> t2
如果每一步都重算整个前缀,历史 token 的 key/value 会被重复计算。KV cache 用来缓存每一层的历史 \(K,V\),下一步只计算新 token 的 \(Q,K,V\),再让新 query 读取历史 cache。
单步推理时,新输入形状为:
\[
X_{\rm new}\in\mathbb{R}^{B\times 1\times d_{\rm model}}.
\]
历史 cache 形状为:
\[
K_{\rm cache},V_{\rm cache}
\in
\mathbb{R}^{B\times h\times n\times d_h}.
\]
新 token 产生:
\[
Q_{\rm new},K_{\rm new},V_{\rm new}
\in
\mathbb{R}^{B\times h\times 1\times d_h}.
\]
拼接后用:
\[
K_{\rm all},V_{\rm all}
\in
\mathbb{R}^{B\times h\times (n+1)\times d_h}
\]
计算当前 token 的 attention 输出。完整细节见 KV Cache。
对 NNQS 的对应
在教学版 NNQS 中,decoder-only Transformer 建模 pair token 的条件概率:
\[
P_\theta(t_i\mid t_{<i}).
\]
把所有位置的 log probability 相加:
\[
\log P_\theta(x)
=
\sum_i\log P_\theta(t_i\mid t_{<i}).
\]
由于采样概率对应波函数模方:
\[
P_\theta(x)=|\psi_\theta(x)|^2,
\]
振幅部分为:
\[
\log A_\theta(x)
=
{1\over 2}\sum_i\log P_\theta(t_i\mid t_{<i}).
\]
因此 decoder 的 causal mask 对 NNQS 很关键:它保证每个 orbital pair token 的概率只依赖左侧已经生成的 pair token。