RNN、GRU 与 LSTM¶
RNN, recurrent neural network,是处理序列的经典模型。它的核心思想是:
把前面读过的 prefix 压缩进一个递推隐藏态。
给定输入序列:
普通 RNN 每一步更新:
这里 \(h_t\) 是 hidden state。它应该包含从 \(x_0\) 到 \(x_t\) 的历史信息。
RNN 和 decoder-only Transformer 的对应¶
decoder-only Transformer 做的是:
RNN 也可以做同样的自回归任务:
然后:
所以 RNN、LSTM、GRU、decoder-only Transformer 都可以建模:
区别在于 prefix 怎么表示:
| 模型 | prefix 表示方式 |
|---|---|
| RNN | 压缩成一个 hidden state |
| LSTM / GRU | 用门控 hidden state 记住更长信息 |
| Transformer | 每个位置通过 attention 直接读取历史 token |
普通 RNN 的问题¶
普通 RNN 每一步都把旧信息和新输入混合成新 \(h_t\)。这会带来两个问题:
- 长期信息容易被覆盖。
- 反向传播时梯度容易消失或爆炸。
从反向传播角度看,梯度要反复乘上类似 \(W_h\) 的矩阵。时间步很长时,乘积可能趋近 0,也可能爆炸。
这就是 LSTM 和 GRU 出现的背景:用门控机制控制信息保留、写入和输出。
LSTM 的核心¶
LSTM 每一步有两个状态:
每个时间步接收:
输出:
LSTM 的核心是三个门和一个候选记忆:
| 符号 | 名字 | 作用 |
|---|---|---|
| \(f_t\) | forget gate | 决定旧记忆保留多少 |
| \(i_t\) | input gate | 决定新信息写入多少 |
| \(\tilde c_t\) | candidate cell | 候选新记忆内容 |
| \(o_t\) | output gate | 决定把多少内部记忆暴露出去 |
完整公式:
最关键的是:
它不是每一步完全重写记忆,而是在旧记忆上做加法式更新。这条 cell state 路径可以看成一种记忆通道。
LSTM 的直观比喻¶
可以把 LSTM 想成一个会记笔记的人。每读一个 token,它问三件事:
对应关系:
| 问题 | LSTM 组件 |
|---|---|
| 忘多少旧信息 | forget gate \(f_t\) |
| 写多少新信息 | input gate \(i_t\) |
| 新内容是什么 | candidate memory \(\tilde c_t\) |
| 输出多少记忆 | output gate \(o_t\) |
GRU¶
GRU, gated recurrent unit,是比 LSTM 更简洁的门控 RNN。它没有单独的 cell state,而是直接维护 hidden state。
常见公式:
其中:
| 门 | 作用 |
|---|---|
| update gate \(z_t\) | 控制保留旧 hidden 还是写入新 hidden |
| reset gate \(r_t\) | 控制计算候选状态时看多少旧 hidden |
GRU 参数更少,计算更轻;LSTM 结构更显式,长期记忆通道更清楚。
自回归 RNN 的 PyTorch 写法¶
下面用 nn.GRU 写一个最小自回归模型。输入 token 形状:
输出:
其中 logits[:, i, :] 预测 tokens[:, i]。
下面代码里的 gather 用来从词表维取出真实 token 对应的 log probability。这个 API 的 shape 规则见 常用 Tensor 操作:索引、gather、拼接、拆分与矩阵乘法。
import torch
import torch.nn as nn
import torch.nn.functional as F
class RNNAutoregressive(nn.Module):
def __init__(self, n_tokens: int, d_model: int, n_layers: int = 1):
super().__init__()
self.n_tokens = n_tokens
self.start_token = n_tokens
self.embedding = nn.Embedding(n_tokens + 1, d_model)
self.rnn = nn.GRU(
input_size=d_model,
hidden_size=d_model,
num_layers=n_layers,
batch_first=True,
)
self.output = nn.Linear(d_model, n_tokens)
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
batch, _ = tokens.shape
start = torch.full(
(batch, 1),
self.start_token,
dtype=tokens.dtype,
device=tokens.device,
)
model_input = torch.cat([start, tokens[:, :-1]], dim=1)
x = self.embedding(model_input)
hidden, _ = self.rnn(x)
return self.output(hidden)
def log_prob(self, tokens: torch.Tensor) -> torch.Tensor:
logits = self.forward(tokens)
logp_all = F.log_softmax(logits, dim=-1)
chosen = logp_all.gather(
dim=-1,
index=tokens.unsqueeze(-1),
).squeeze(-1)
return chosen.sum(dim=-1)
生成时缓存什么¶
Transformer 推理时缓存历史 K/V:
RNN 更简单,只缓存一个 hidden state:
所以 RNN 生成时很省缓存,但表达能力受限于 hidden state 容量。
| 模型 | 生成时缓存 |
|---|---|
| RNN / GRU / LSTM | hidden state,LSTM 还缓存 cell state |
| Transformer | 每层每个历史 token 的 key/value |
| PixelCNN | 通常不方便高效缓存 |
| RBM | 不按顺序生成,常用 Gibbs sampling |
LSTM 和 Transformer 的区别¶
LSTM 是递推结构:
所以它很难在时间维完全并行:
Transformer 的 self-attention 可以让所有 token 同时计算,并通过 causal mask 保证不看未来:
对比:
| 模型 | 优点 | 缺点 |
|---|---|---|
| RNN | 简单,天然序列递推 | 长期依赖弱,难并行 |
| LSTM / GRU | 门控缓解长期依赖 | 仍然顺序计算,超长依赖有限 |
| Transformer | 并行性强,长距离路径短 | attention 复杂度高,需要位置编码和 KV cache |
历史上可以粗略看成: