跳转至

Decoder 的基本结构

Decoder 的核心任务是自回归预测:给定左侧上下文,预测下一个 token。

\[ P(t_0,\ldots,t_{N-1}) = \prod_{i=0}^{N-1}P(t_i\mid t_{<i}). \]

GPT 和教学版 NNQS 中的 AmplitudeTransformer 都属于这个思路。区别在于:语言模型预测自然语言 token,NNQS 预测 orbital pair token。

输入是什么

输入 token id 先经过 embedding 与位置编码:

\[ T\in\mathbb{N}^{B\times N} \quad\longrightarrow\quad X\in\mathbb{R}^{B\times N\times d_{\rm model}}. \]

这里:

  • \(B\):batch size。
  • \(N\):序列长度。
  • \(d_{\rm model}\):每个 token 的隐藏维度。

进入 decoder block 的 \(X\) 已经是一组连续向量,不再是离散 token id。

一个 Decoder Block

一个 decoder block 可以看成:

masked self-attention
  -> residual + layer norm
  -> feed-forward network
  -> residual + layer norm

常见现代实现使用 Pre-LN:

\[ Y=X+\mathrm{MHA}(\mathrm{LayerNorm}(X)), \]
\[ Z=Y+\mathrm{FFN}(\mathrm{LayerNorm}(Y)). \]

其中 \(X,Y,Z\) 的形状都相同:

\[ X,Y,Z\in\mathbb{R}^{B\times N\times d_{\rm model}}. \]

这也是 Transformer block 容易堆叠的原因:每一层输入输出形状保持一致。

Masked Self-Attention

Decoder 的 self-attention 与普通 self-attention 的差别在 mask。第 \(i\) 个位置只能读取 \(j\le i\) 的 token:

\[ \tilde S_{ij} = \begin{cases} S_{ij}, & j\le i,\\ -\infty, & j>i. \end{cases} \]

其中:

\[ S={QK^{\mathsf T}\over\sqrt{d_h}}. \]

softmax 后,未来位置的权重为 \(0\),所以第 \(i\) 个位置的输出只依赖左侧上下文。

多头与输出投影

设有 \(h\) 个 head,每个 head 的维度为:

\[ d_h={d_{\rm model}\over h}. \]

每个 head 独立计算:

\[ H_a=\mathrm{Attention}(Q_a,K_a,V_a), \qquad H_a\in\mathbb{R}^{B\times N\times d_h}. \]

拼接后:

\[ H=\mathrm{Concat}(H_1,\ldots,H_h) \in\mathbb{R}^{B\times N\times d_{\rm model}}. \]

再经过输出投影:

\[ \mathrm{MHA}(X)=HW^O, \qquad W^O\in\mathbb{R}^{d_{\rm model}\times d_{\rm model}}. \]

所以 attention 子层输出仍然是:

\[ \mathrm{MHA}(X)\in\mathbb{R}^{B\times N\times d_{\rm model}}. \]

FFN 的角色

Attention 负责让不同位置交换信息。FFN 对每个位置独立做非线性变换:

\[ \mathrm{FFN}(x) = \sigma(xW_1+b_1)W_2+b_2. \]

其中:

\[ W_1\in\mathbb{R}^{d_{\rm model}\times d_{\rm ff}}, \qquad W_2\in\mathbb{R}^{d_{\rm ff}\times d_{\rm model}}. \]

常见选择是:

\[ d_{\rm ff}\approx 4d_{\rm model}. \]

FFN 不混合不同 token 位置;它只把每个位置已经聚合到的上下文信息再加工一次。

输出 Logits

堆叠 \(L\) 个 decoder block 后得到:

\[ Z_L\in\mathbb{R}^{B\times N\times d_{\rm model}}. \]

输出层把每个位置映射到词表:

\[ \mathrm{logits}=Z_LW_{\rm vocab}+b_{\rm vocab}, \]

其中:

\[ W_{\rm vocab}\in\mathbb{R}^{d_{\rm model}\times V}. \]

于是:

\[ \mathrm{logits}\in\mathbb{R}^{B\times N\times V}. \]

\(i\) 个位置的 logits 用来预测下一个 token。训练时常写成:

输入: [BOS, t0, t1, ..., t_{N-2}]
目标: [t0,  t1, t2, ..., t_{N-1}]

训练:并行但不能看未来

训练时可以把整段序列一次送入模型。虽然每个位置通过 causal mask 看不到未来,所有位置的矩阵乘法仍然可以并行计算。

这就是 decoder-only Transformer 的一个妙处:

信息流是自回归的
计算图是并行的

所以训练时通常不需要 KV cache。

推理:逐 Token 生成

推理或采样时,模型不能一次知道完整目标序列,只能一步一步生成:

[BOS]
  -> t0
[BOS, t0]
  -> t1
[BOS, t0, t1]
  -> t2

如果每一步都重算整个前缀,历史 token 的 key/value 会被重复计算。KV cache 用来缓存每一层的历史 \(K,V\),下一步只计算新 token 的 \(Q,K,V\),再让新 query 读取历史 cache。

单步推理时,新输入形状为:

\[ X_{\rm new}\in\mathbb{R}^{B\times 1\times d_{\rm model}}. \]

历史 cache 形状为:

\[ K_{\rm cache},V_{\rm cache} \in \mathbb{R}^{B\times h\times n\times d_h}. \]

新 token 产生:

\[ Q_{\rm new},K_{\rm new},V_{\rm new} \in \mathbb{R}^{B\times h\times 1\times d_h}. \]

拼接后用:

\[ K_{\rm all},V_{\rm all} \in \mathbb{R}^{B\times h\times (n+1)\times d_h} \]

计算当前 token 的 attention 输出。完整细节见 KV Cache

对 NNQS 的对应

在教学版 NNQS 中,decoder-only Transformer 建模 pair token 的条件概率:

\[ P_\theta(t_i\mid t_{<i}). \]

把所有位置的 log probability 相加:

\[ \log P_\theta(x) = \sum_i\log P_\theta(t_i\mid t_{<i}). \]

由于采样概率对应波函数模方:

\[ P_\theta(x)=|\psi_\theta(x)|^2, \]

振幅部分为:

\[ \log A_\theta(x) = {1\over 2}\sum_i\log P_\theta(t_i\mid t_{<i}). \]

因此 decoder 的 causal mask 对 NNQS 很关键:它保证每个 orbital pair token 的概率只依赖左侧已经生成的 pair token。