05 LLaMA3 Block — 从零件到发动机
LLM 算子深度解析:05 LLaMA3 Block — 从零件到发动机
1. 为什么面试官最爱问 Block 组装?
前面四节我们分别拆解了 RMSNorm、SwiGLU、RoPE、GQA——但面试官不会只问零件。他们会问:“LLaMA 的 Decoder Layer 长什么样?数据从输入到输出怎么流的?画一下。”
这一节就是把零件组装成发动机。理解了整个 Block 的数据流,你才能回答"为什么 LLaMA 比 GPT-2 好训练"这种全局问题。
2. LLaMA vs 传统 Transformer:五次精准的"外科手术"
LLaMA 对传统 Transformer(如 GPT-2)做了五点改动,每一点都精准击中一个工程痛点。这不是堆砌 trick,而是用数学上的精巧代替工程上的蛮力。
| 改进点 | 传统做法(GPT-2) | LLaMA 做法 | 解决的痛点 |
|---|---|---|---|
| 归一化位置 | Post-Norm(子层之后归一化) | Pre-Norm(子层之前归一化) | 深层网络梯度不稳定 |
| 归一化算法 | LayerNorm(减均值 + 除方差 + 仿射) | RMSNorm(只除 RMS,无均值无偏置) | 10-15% 计算冗余 |
| 激活函数 | ReLU / GELU | SwiGLU(门控机制) | 负值区域神经元死亡 |
| 位置编码 | 绝对位置编码(正弦波/可学习) | RoPE(旋转位置编码) | 无法长度外推 |
| 注意力机制 | MHA(标准多头) | GQA(分组查询) | KV Cache 显存爆炸 |
2.1 Pre-Norm vs Post-Norm:归一化放前面还是后面?
这是训练稳定性的分水岭。
Post-Norm(GPT-2):
x → Attention(x) → LayerNorm → x + residual
↑ 先计算,再归一化
Pre-Norm(LLaMA):
x → RMSNorm(x) → Attention → x + residual
↑ 先归一化,再计算
深层网络的梯度困境:Post-Norm 下,第 80 层的梯度要穿过 80 个 Attention + 80 个 LayerNorm 才能回到第 1 层。每个子层都可能放大或缩小梯度——80 层下来,梯度要么爆炸要么消失。
Pre-Norm 把归一化放在子层入口,确保每个子层接收到的输入都是归一化过的。梯度在残差连接上有一条"高速公路"直达浅层。
一句话:Post-Norm 是先做饭再洗手,Pre-Norm 是先洗手再做饭。干净多了,训练也稳定多了。
2.2 SwiGLU 的门控哲学
ReLU 的问题是粗暴:负值区域梯度为 0 → 神经元"死亡" → 再也救不回来。
SwiGLU 用两条通路 + 门控替代了这种一刀切:
SwiGLU(x)=(SiLU(xWgate)⊙xWup)Wdown\text{SwiGLU}(x) = (\text{SiLU}(x W_{\text{gate}}) \odot x W_{\text{up}}) W_{\text{down}}SwiGLU(x)=(SiLU(xWgate)⊙xWup)Wdown
- gate 通路:xWgatex W_{\text{gate}}xWgate → SiLU → 输出 0~1 的"门控信号"
- up 通路:xWupx W_{\text{up}}xWup → 承载实际信息
- 逐元素乘:门控信号 × 信息 → “这个特征保留 30%,那个特征保留 80%”
传统 FFN 是 ReLU(x·W1)·W2——只有一条通路,ReLU 一刀切。SwiGLU 让模型学会选择性放行。
3. 数学定义:Decoder Layer 的完整公式
3.1 两层结构
输入: x ∈ R^{B×S×d}
Attention Block:
h = x + GQA(RoPE(RMSNorm(x)))
↑ ↑ ↑
残差 位置编码 Pre-Norm
MLP Block:
out = h + SwiGLU(RMSNorm(h))
↑ ↑ ↑
残差 门控MLP Pre-Norm
3.2 为什么 intermediate_size = (8/3) × hidden_size?
这是面试高频题。推导只需要小学算术:
传统 MLP:d → 4d → d,参数量 = d×4d+4d×d=8d2d \times 4d + 4d \times d = 8d^2d×4d+4d×d=8d2
SwiGLU MLP:gate: d → d_inter,up: d → d_inter,down: d_inter → d,参数量 = 3d×dinter3d \times d_{\text{inter}}3d×dinter
令两者相等:
3d×dinter=8d2⇒dinter=83d≈2.67d3d \times d_{\text{inter}} = 8d^2 \quad \Rightarrow \quad d_{\text{inter}} = \frac{8}{3}d \approx 2.67d3d×dinter=8d2⇒dinter=38d≈2.67d
具体数值:LLaMA-3 8B 的 hidden_size = 4096 → d_inter = 4096 × 8/3 ≈ 10923 → 实际取 14336(对齐到 128 的倍数,适配 Tensor Core)。
直觉:传统 MLP 用 4× 扩展是因为只有 2 个投影层;SwiGLU 有 3 个投影层,每个自然要窄一些,总参数量才能持平。
4. 代码实现:数据流的骨架
4.1 SwiGLU MLP
class LlamaMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# gate: [B, S, d] → [B, S, d_inter] → SiLU → 门控信号
# up: [B, S, d] → [B, S, d_inter] → 信息通路
# 逐元素乘: gate ⊙ up → [B, S, d_inter]
# down: [B, S, d_inter] → [B, S, d]
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
bias=False 的原因:在大规模模型中,偏置项参数量占比极小(d 维偏置 vs d×d_inter 维权重),但去除偏置能减少显存碎片、简化 kernel launch。LLaMA 全系列都无偏置——极简主义的又一体现。
维度追踪(LLaMA-3 8B):
输入 x: [B, S, 4096]
gate_proj(x): [B, S, 14336] (4096 → 14336)
up_proj(x): [B, S, 14336]
F.silu(gate) * up: [B, S, 14336] 逐元素乘,维度不变
down_proj(结果): [B, S, 4096] (14336 → 4096)
4.2 LLaMA Decoder Layer:残差连接的四步法则
class LlamaDecoderLayer(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.input_layernorm = RMSNorm(hidden_size) # Pre-Norm for Attention
self.self_attn = GroupedQueryAttention(...) # 内含 RoPE + GQA
self.post_attention_layernorm = RMSNorm(hidden_size) # Pre-Norm for MLP
self.mlp = LlamaMLP(hidden_size, intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# ---- Attention Block ----
residual = hidden_states # ① 保存残差
hidden_states = self.input_layernorm(hidden_states) # ② Pre-Norm
hidden_states = self.self_attn(hidden_states) # ③ GQA + RoPE
hidden_states = residual + hidden_states # ④ 残差连接
# ---- MLP Block ----
residual = hidden_states # ⑤ 保存残差
hidden_states = self.post_attention_layernorm(hidden_states) # ⑥ Pre-Norm
hidden_states = self.mlp(hidden_states) # ⑦ SwiGLU MLP
hidden_states = residual + hidden_states # ⑧ 残差连接
return hidden_states
必须严格遵守的四步顺序:保存残差 → 归一化 → 子层计算 → 加回残差
为什么残差在归一化之前保存? 残差分支提供"未经归一化的原始信号",归一化分支提供"归一化后经过子层处理的信号"。两者相加 = 原始信号 + 变换信号。如果残差也被归一化,"高速公路"就变成普通道路——梯度无法直达浅层。
4.3 命名陷阱:post_attention_layernorm 到底是什么?
| 属性名 | 实际含义 | 位置 |
|---|---|---|
input_layernorm |
Attention 的 Pre-Norm | Attention 之前 |
post_attention_layernorm |
MLP 的 Pre-Norm | Attention 之后、MLP 之前 |
post_attention_layernorm 名字里带 “post”,但它在 MLP 之前做归一化——仍然是 Pre-Norm 架构。它叫 “post_attention” 是因为它在 Attention 之后。理解这个命名对阅读 HuggingFace 源码很关键。
5. 工业对照:HuggingFace LLaMA 源码
HF 的实现逻辑完全一致,额外加了 KV Cache 管理和 attention_mask 传递:
# HuggingFace transformers/models/llama/modeling_llama.py (简化)
class LlamaDecoderLayer(nn.Module):
def forward(self, hidden_states, attention_mask=None, position_ids=None,
past_key_value=None, ...):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _, past_key_value = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value, # KV Cache 支持
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states # + past_key_value
LLaMA-3 各规格参数一览
| 模型 | 层数 | hidden_size | intermediate_size | Q头数 | KV头数 | GQA 压缩比 |
|---|---|---|---|---|---|---|
| LLaMA-3 8B | 32 | 4096 | 14336 | 32 | 8 | 4:1 |
| LLaMA-3 70B | 80 | 8192 | 28672 | 64 | 8 | 8:1 |
| LLaMA-3 405B | 126 | 16384 | 53248 | 128 | 8 | 16:1 |
规律:KV 头数固定为 8 → 模型越大,GQA 压缩比越高;intermediate_size ≈ hidden_size × 8/3,对齐到 128 的倍数。
6. 踩坑实录
| 坑 | 现象 | 根因 | 解决 |
|---|---|---|---|
| 忘记保存残差 | loss 不下降或 → NaN | 直接在归一化后的值上操作,梯度路径被掐断 | 严格遵循"保存→归一化→计算→加回"四步 |
| 两个 RMSNorm 共享参数 | 训练效果差 | Attention 和 MLP 前的归一化需要各自学习不同的 γ | 分别创建两个 RMSNorm 实例 |
| gate 和 up 顺序搞反 | 能跑但效果不如正确版本 | F.silu(up) * gate 语义错误——gate 走激活、up 不走 |
F.silu(gate_proj(x)) * up_proj(x) |
| intermediate_size 用了 4× | 参数量比预期大 50% | SwiGLU 有 3 个投影层,用 4× 导致参数暴涨 | 用 8/3 × hidden_size |
忘记 .contiguous() |
RuntimeError: view size is not compatible |
transpose 后内存不连续 | .contiguous().view(...) |
7. 延伸思考:LLaMA 架构的"极简主义"设计哲学
回顾这五点改进,每条都是做减法:
- LayerNorm → RMSNorm:去掉均值和偏置——不减精度,只减计算
- ReLU → SwiGLU:更精细的控制——不是更复杂,是更聪明
- 绝对编码 → RoPE:不再"查表",而是"算角度"——数学代替存储
- MHA → GQA:减少 KV 头——用结构换显存
- Post-Norm → Pre-Norm:把归一化挪到前面——结构更简单,训练更稳定
核心洞察:大模型架构的进化方向不是"堆更多东西",而是用数学上的精巧代替工程上的堆砌。每减掉一个组件,都留下了一个更优雅的方案。
值得深挖的方向
- Parallel Transformer Block(GPT-J):把 Attention 和 MLP 并行计算,再合并残差——速度更快但精度略降
- Sandwich Norm:Pre-Norm + Post-Norm 混合体,在某些训练阶段更稳定
- MoE(07节):把 MLP 层替换为多个专家 + 路由器——Mixtral/DeepSeek 的架构基础
- FlashAttention + Block(14节):完整的端到端实现
下一篇:[[06 MoE Router]] — 把 MLP 切成 8 个专家,让每个 token 只找最懂行的 2 个干活。
更多推荐



所有评论(0)