Transformer 是基于全局视角处理序列的,它通过 Self-Attention 一次性让所有 token 互相交互。而 RNN(Recurrent Neural Network,循环神经网络)的本质是基于时间步的顺序迭代。

如果把 Transformer 比作同时看到整段句子的“上帝视角”,RNN 则像是按照从左到右的顺序逐字阅读,并在脑海中维护一个不断更新的“记忆向量”。

RNN 的核心结构是一个循环单元。它在处理序列时,并不是一次性输入整个序列(如 [batch, seq_len, dim]),而是将序列沿着 seq_len 维度切开,在每一个时间步 $t$ 输入一个 token。

为了保留历史上下文,RNN 引入了隐藏状态 $h_t$。

  • $h_t$ 是一个固定维度的向量,它包含了从时间步 $0$ 到 $t$ 的所有历史信息压缩。
  • 在每一个时间步,RNN 接收两个输入:当前时刻的输入 $x_t$ 和 上一时刻的隐藏状态 $h_{t-1}$。
  • RNN 会使用同一套权重矩阵(参数共享)来融合这两个输入,生成新的 $h_t$。

对于序列中的第 $t$ 个 token,RNN 内部只做两步基本的线性变换和一次非线性激活:

  • 更新隐藏状态:

    $$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$

    这里 $W_{xh}$ 负责投影当前输入,$W_{hh}$ 负责投影历史记忆。两者的结果相加后,通过 $\tanh$ 激活函数将数值压缩到 $[-1, 1]$ 之间,防止在长序列迭代中数值爆炸。

  • 计算当前输出(可选):

    $$y_t = W_{hy} h_t + b_y$$

    如果需要每个时间步都输出(比如序列标注),就用当前的 $h_t$ 映射出 $y_t$。如果只需要句向量(比如文本分类),则直接取最后一个时间步的 $h_n$ 即可。

伪代码实现

注意,这里的 for 循环是 RNN 架构的灵魂,也是它与 Transformer 最大的区别

import torch
import torch.nn as nn

class VanillaRNN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        
        # Linear layer for current input token
        self.W_xh = nn.Linear(input_size, hidden_size)
        
        # Linear layer for previous hidden state (the recurrence mechanism)
        self.W_hh = nn.Linear(hidden_size, hidden_size)
        
        # Linear layer to generate final output from hidden state
        self.W_hy = nn.Linear(hidden_size, output_size)
        
        # Non-linear activation to bound the hidden state values
        self.tanh = nn.Tanh()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # x shape: (batch_size, seq_len, input_size)
        batch_size, seq_len, _ = x.size()
        
        # Step 1: Initialize the first hidden state h_0 with zeros
        # Shape: (batch_size, hidden_size)
        h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
        
        outputs = []
        
        # Step 2: Unroll the sequence over the time dimension
        # THIS is the core of RNN. It forces sequential computation.
        for t in range(seq_len):
            # Extract the token embedding at time step t
            x_t = x[:, t, :]  # Shape: (batch_size, input_size)
            
            # Step 3: Compute new hidden state h_t
            # Combine current input and previous history
            h_t = self.tanh(self.W_xh(x_t) + self.W_hh(h_t))
            
            # Step 4: Compute output for the current time step
            y_t = self.W_hy(h_t)
            outputs.append(y_t)
            
        # Stack all outputs along the sequence dimension
        # final_outputs shape: (batch_size, seq_len, output_size)
        final_outputs = torch.stack(outputs, dim=1)
        
        # Return both the sequence of outputs and the final hidden state
        return final_outputs, h_t

RNN vs Transformer

基于上述工作流,这两种架构在工程特性上有显著差异:

  • 计算图与并行性
    • Transformer:输入序列可以一次性通过矩阵乘法完成 Attention 计算,极其适合 GPU 的高度并行计算。
    • RNN:必须等待 $h_{t-1}$ 计算完毕,才能计算 $h_t$。这种时序依赖导致它在训练时无法沿着序列维度并行,这也是它被 Transformer 取代的主要工程原因。
  • 上下文容量
    • Transformer:Attention 的复杂度是 $O(N^2)$,$N$ 个 token 两两交互,记忆是无损的(不考虑 KV Cache 容量)。
    • RNN:无论序列有多长,所有的历史信息都被强制压缩在一个固定维度大小的向量 $h_t$ 中。这会导致信息瓶颈(Information Bottleneck)和长距离依赖下的梯度消失(Vanishing Gradient)。
  • 推理成本
    • Transformer:自回归生成时,随着序列变长,KV Cache 的显存占用线性增长,Attention 的计算量也在增加。
    • RNN:由于只维护一个固定大小的 $h_t$ 向量,它的推理复杂度是 $O(1)$,不需要缓存过去的 token,这在长文本生成的推理端是一个巨大优势。