大模型教我成为大模型算法工程师之day10:循环神经网络 (RNN)
摘要:循环神经网络(RNN)通过引入时间维度的记忆机制处理序列数据,但存在梯度消失问题。LSTM通过遗忘门、输入门和输出门控制信息流动,利用细胞状态(Cell State)实现长距离依赖。GRU作为简化版,合并状态和门控机制,提升效率。双向RNN同时考虑上下文信息,而Seq2Seq架构为机器翻译奠定基础。虽然Transformer主导NLP领域,但LSTM/GRU在小模型和实时计算中仍具优势。本文
Day 10: 循环神经网络 (RNN)
摘要:人类阅读时不会每看一个词都把前面的忘了,我们的思维是连贯的。循环神经网络 (RNN) 赋予了机器这种“记忆”能力。本文将带你理解 RNN 如何处理序列数据,剖析其致命弱点(梯度消失),并深入图解 LSTM 和 GRU 是如何通过精妙的“门控机制”解决长距离依赖问题的。
1. 为什么需要 RNN?
在 CNN 和 MLP 中,输入和输出是独立的(看这张猫图和看下一张狗图没关系)。
但在处理序列数据 (Sequence Data) 时,前后的顺序至关重要:
- 自然语言:“我 喜欢 吃 苹果” vs “苹果 喜欢 吃 我”。
- 时间序列:今天的股票价格依赖于过去几天的走势。
- 语音:当前的音素与前后的发音紧密相关。
RNN 的核心思想是:不仅利用当前的输入,还利用上一时刻的“状态” (Hidden State)。
2. RNN 基础原理
2.1 结构拆解
RNN 可以看作是一个在时间轴上展开的网络。
- x t x_t xt: t t t 时刻的输入。
- h t h_t ht: t t t 时刻的隐状态(记忆)。
- 核心公式:
h t = tanh ( W i h x t + W h h h t − 1 + b ) h_t = \tanh(W_{ih} x_t + W_{hh} h_{t-1} + b) ht=tanh(Wihxt+Whhht−1+b)
y t = W h y h t + b y y_t = W_{hy} h_t + b_y yt=Whyht+by
名词解释:权重矩阵 (Weight Matrix)
公式里的 W W W 就是神经网络要学习的参数。
- W i h W_{ih} Wih:决定了当前输入 x t x_t xt 如何影响新的记忆。
- W h h W_{hh} Whh:决定了旧记忆 h t − 1 h_{t-1} ht−1 如何转化为新记忆。
梯度消失/爆炸就是因为 W h h W_{hh} Whh 在时间维度上连乘了太多次(几百次循环),导致数值失控。
通俗理解:
想象你在读一本书。
x t x_t xt 是你当前看到的词, h t − 1 h_{t-1} ht−1 是你脑子里对前文的记忆。
你把“当前词”和“前文记忆”融合,生成新的记忆 h t h_t ht,并以此预测下一个词 y t y_t yt。
2.2 致命弱点:梯度消失/爆炸
在反向传播时(BPTT),梯度需要通过时间维度一步步往回传。
- 如果权重矩阵 W < 1 W < 1 W<1,连乘多次后梯度趋近于 0 —— 梯度消失 (Vanishing Gradient)。这导致网络**“记不住”**很久以前的信息(例如开头说了 “Alice”,结尾忘了是 “She”)。
- 如果权重矩阵 W > 1 W > 1 W>1,连乘多次后梯度趋近无穷大 —— 梯度爆炸 (Exploding Gradient)。
3. LSTM (Long Short-Term Memory)
为了解决梯度消失,Schmidhuber 等人在 1997 年提出了 LSTM。
LSTM 的设计哲学是:专门搞一个“高速公路”(Cell State),让信息能无损地流传下去。
3.1 核心组件:三个门 (Gates)
所有的门都由 Sigmoid 函数控制(输出 0~1),0 代表关(遗忘/拦截),1 代表开(通过)。
- 遗忘门 (Forget Gate) f t f_t ft:决定丢弃哪些旧信息。
- “这句话讲完了,句号了,把上一句的主语忘了吧。”
- 输入门 (Input Gate) i t i_t it:决定存入哪些新信息。
- “这个词‘但是’很重要,表示转折,得记下来。”
- 输出门 (Output Gate) o t o_t ot:决定当前的隐状态 h t h_t ht 输出什么。
3.2 细胞状态 (Cell State) C t C_t Ct
这是 LSTM 的核心——信息的传送带。
C t = f t ⋅ C t − 1 + i t ⋅ C ~ t C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C}_t Ct=ft⋅Ct−1+it⋅C~t
- 旧记忆 C t − 1 C_{t-1} Ct−1 乘以遗忘门(保留一部分)。
- 加上新记忆 C ~ t \tilde{C}_t C~t 乘以输入门(存入一部分)。
- 关键点:这里是加法更新,而不是乘法。加法让梯度能更稳定地回传,避免了连乘导致的消失问题。
4. GRU (Gated Recurrent Unit)
LSTM 太复杂了(三个门,两个状态),计算慢。
2014 年,Bengio 团队提出了 GRU,它是 LSTM 的简化版,效果差不多,但跑得更快。
4.1 简化策略
- 合并状态:把 Cell State C t C_t Ct 和 Hidden State h t h_t ht 合二为一。
- 合并门:把遗忘门和输入门合并为 更新门 (Update Gate) z t z_t zt。
- z t z_t zt 决定了:我是保留旧记忆,还是用新输入替换它?
- 重置门 (Reset Gate) r t r_t rt:决定如何将新的输入与旧记忆结合。
4.2 LSTM vs GRU
- GRU:参数少(3组权重),训练快,小数据集首选。
- LSTM:参数多(4组权重),表达能力稍强,大数据集或超长序列可能更好。
5. 双向 RNN (Bidirectional RNN)
有些时候,我们不仅需要知道“过去”,还需要知道“未来”。
- 例子:填空题 “He said ____ to me.”
- 只看前面 “He said”,可能是 “hello”, “goodbye”, “nothing”。
- 看了后面 “to me”,范围没变。
- 如果句子是 “He said ____ to the crowd.”,那可能是 “loudly”。
- 原理:两个 RNN,一个从左往右读,一个从右往左读。最后把两个 h t h_t ht 拼起来。
6. Seq2Seq 与 Encoder-Decoder
这是 RNN 最辉煌的应用形式,也是机器翻译的基础。
- Encoder (编码器):把输入序列(中文)压缩成一个上下文向量 (Context Vector)。
- Decoder (解码器):根据这个向量,一步步生成输出序列(英文)。
问题:不管句子多长,都要压缩成一个固定长度的向量,容易**“消化不良”**(信息丢失)。这直接催生了后来的 Attention 机制。
7. 代码实践:LSTM 文本分类
import torch
import torch.nn as nn
class RNNClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, n_layers, dropout):
super().__init__()
# 1. Embedding层:把词索引变成稠密向量
self.embedding = nn.Embedding(vocab_size, embed_dim)
# 2. LSTM层
# batch_first=True -> (batch, seq_len, features)
# bidirectional=True -> 双向LSTM
self.lstm = nn.LSTM(embed_dim,
hidden_dim,
num_layers=n_layers,
bidirectional=True,
dropout=dropout,
batch_first=True)
# 3. 全连接层
# 双向LSTM输出维度是 hidden_dim * 2
self.fc = nn.Linear(hidden_dim * 2, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, text):
# text: [batch_size, seq_len]
embedded = self.dropout(self.embedding(text))
# embedded: [batch_size, seq_len, embed_dim]
# output: 每个时间步的输出
# hidden: 最后一个时间步的隐状态 (h_n, c_n)
output, (hidden, cell) = self.lstm(embedded)
# hidden: [n_layers * n_directions, batch_size, hidden_dim]
# 我们取最后一层的最后时刻状态
# 提取正向和反向的最后状态
hidden_forward = hidden[-2,:,:]
hidden_backward = hidden[-1,:,:]
# 拼接 (Concatenate):把两个向量并排粘在一起
# 结果维度: [batch_size, hidden_dim * 2]
# 不是相加也不是平均,而是保留正反向的所有信息
hidden_final = torch.cat((hidden_forward, hidden_backward), dim=1)
return self.fc(hidden_final)
8. 总结
- RNN 引入了时间维度的记忆。
- LSTM 用“门”和“传送带”解决了长距离遗忘的问题。
- GRU 是 LSTM 的高效简化版。
- Bi-RNN 同时看上下文。
- Seq2Seq 开启了序列生成的时代。
虽然 Transformer 现在接管了 NLP,但在小模型、流式计算(实时语音)、时间序列预测等领域,LSTM/GRU 依然有一席之地。
参考资料
更多推荐


所有评论(0)