Vanilla RNN 的缺陷:为什么我们需要 LSTM?
在上一篇中我们提到,RNN 的核心公式是 $h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$。这种设计在处理长序列时会遇到两个工程上的致命问题:
- 梯度消失与梯度爆炸(Vanishing/Exploding Gradients)
- 在反向传播计算梯度时,误差需要沿着时间步反向传递。这会导致权重矩阵 $W_{hh}$ 被连乘 $t$ 次。根据线性代数原理,如果 $W_{hh}$ 的最大特征值小于 1,连乘后梯度会呈指数级衰减趋近于 0(梯度消失);如果大于 1,则会指数级放大(梯度爆炸)。梯度消失意味着网络根本无法学习到长距离的依赖关系。
- 信息覆盖(Information Overwrite)
- RNN 只有一个隐藏状态 $h_t$。在每一个时间步,新的输入 $x_t$ 都会强制与历史信息 $h_{t-1}$ 混合。没有任何机制能够保护早期非常重要但最近没有出现的信息。这就好比一个容量有限的栈,新数据不断涌入,旧数据很快就被冲刷掉了。
LSTM 的核心思想:分离状态与引入门控机制
为了解决上述问题,LSTM 对架构进行了大改,其核心创新在于:将内部状态拆分为两个,并引入了“门(Gates)”来进行精确的信息路由。
- 细胞状态 $c_t$ (Cell State): 这是 LSTM 的“主干道”或“长期记忆”。它在整个链条上贯穿运行,只有一些少量的线性交互。这种设计使得梯度可以通过 $c_t$ 顺畅地无损反向传播,直接解决了梯度消失问题。
- 隐藏状态 $h_t$ (Hidden State): 类似于 Vanilla RNN 的 $h_t$,作为“短期记忆”或当前时间步的输出。
- 门控机制 (Gating Mechanism): 门本质上是经过 Sigmoid 激活的全连接层。Sigmoid 的输出在 $[0, 1]$ 之间,用于控制信息的保留比例(0 代表完全丢弃,1 代表完全保留)。
数学工作流:LSTM 的四个核心步骤
对于第 $t$ 个 token,LSTM 内部执行以下运算。为了方便,我们通常将前一个隐藏状态 $h_{t-1}$ 和当前输入 $x_t$ 拼接在一起计算。
第一步:遗忘门 (Forget Gate) —— 决定丢弃什么历史信息
$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$$f_t$ 是一个与 $c_{t-1}$ 维度相同的向量。它的值在 0 到 1 之间,决定了长期记忆 $c_{t-1}$ 中哪些维度的数据应该被清空。
第二步:输入门 (Input Gate) —— 决定写入什么新信息 我们需要计算两部分:哪些维度需要更新($i_t$),以及新的候选内容是什么($\tilde{c}_t$)。
$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$$$\tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$$$i_t$ 充当开关,$\tilde{c}_t$ 是当前时刻提取出的新特征。
第三步:更新细胞状态 (Update Cell State)
$$c_t = f_t * c_{t-1} + i_t * \tilde{c}_t$$这里的 $*$ 表示逐元素乘法,即 Hadamard Product
这是 LSTM 最精妙的一步:旧记忆 $c_{t-1}$ 乘以遗忘门(按比例遗忘),然后加上新特征 $\tilde{c}_t$ 乘以输入门(按比例写入)。这是一个纯线性的加法过程,梯度可以完美地沿着加号回传。
第四步:输出门 (Output Gate) —— 决定输出什么
$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$$$h_t = o_t * \tanh(c_t)$$根据当前输入和历史决定输出开关 $o_t$,然后将长期记忆 $c_t$ 通过 $\tanh$ 压缩到 $[-1, 1]$ 之间,乘以输出开关,得到当前时刻的隐藏状态 $h_t$。
代码与工作流解析
以下是标准 LSTM 单元的伪代码实现,展示了参数是如何定义以及前向传播是如何完成的。
import torch
import torch.nn as nn
class LSTMCell(nn.Module):
def __init__(self, input_size: int, hidden_size: int):
super().__init__()
self.hidden_size = hidden_size
# In practice, PyTorch fuses all 4 linear transformations into one large matrix
# for performance via parallel matrix multiplication.
# W shape: (input_size + hidden_size, 4 * hidden_size)
self.W_ih = nn.Linear(input_size, 4 * hidden_size)
self.W_hh = nn.Linear(hidden_size, 4 * hidden_size)
def forward(self, x_t: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
# x_t shape: (batch_size, input_size)
# states = (h_{t-1}, c_{t-1}), each shape: (batch_size, hidden_size)
h_prev, c_prev = states
# Step 1: Compute all gate pre-activations simultaneously
# gates shape: (batch_size, 4 * hidden_size)
gates = self.W_ih(x_t) + self.W_hh(h_prev)
# Split the fused tensor into 4 chunks for f, i, o, and c_candidate
# Each chunk shape: (batch_size, hidden_size)
forget_gate, input_gate, output_gate, cell_candidate = gates.chunk(4, dim=1)
# Step 2: Apply non-linear activations
f_t = torch.sigmoid(forget_gate)
i_t = torch.sigmoid(input_gate)
o_t = torch.sigmoid(output_gate)
c_tilde = torch.tanh(cell_candidate)
# Step 3: Update Cell State (The core gradient superhighway)
# Element-wise operations
c_t = (f_t * c_prev) + (i_t * c_tilde)
# Step 4: Compute new Hidden State
h_t = o_t * torch.tanh(c_t)
# Return the updated states for the next time step or final output
return h_t, c_t
class LSTM(nn.Module):
def __init__(self, input_size: int, hidden_size: int):
super().__init__()
self.cell = LSTMCell(input_size, hidden_size)
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
# x shape: (batch_size, seq_len, input_size)
batch_size, seq_len, _ = x.size()
# Initialize h_0 and c_0 with zeros
h_t = torch.zeros(batch_size, self.cell.hidden_size, device=x.device)
c_t = torch.zeros(batch_size, self.cell.hidden_size, device=x.device)
outputs = []
# Unroll over time (Sequential nature remains)
for t in range(seq_len):
x_t = x[:, t, :]
h_t, c_t = self.cell(x_t, (h_t, c_t))
outputs.append(h_t)
final_outputs = torch.stack(outputs, dim=1)
return final_outputs, (h_t, c_t)
总结
LSTM 通过引入复杂的门控机制和跨越时间步的残差连接(即 $c_t = c_{t-1} + \dots$ 的加法结构),极大地缓解了 RNN 的梯度消失问题,使得模型能够捕获更长距离的依赖。但正如代码中 for t in range(seq_len) 所示,它的本质仍然是基于马尔可夫链的时序迭代,这意味着它依旧无法像 Transformer 那样在训练阶段实现高度的并行计算。