LLM 算子深度解析:05 LLaMA3 Block — 从零件到发动机


1. 为什么面试官最爱问 Block 组装?

前面四节我们分别拆解了 RMSNorm、SwiGLU、RoPE、GQA——但面试官不会只问零件。他们会问:“LLaMA 的 Decoder Layer 长什么样?数据从输入到输出怎么流的?画一下。”

这一节就是把零件组装成发动机。理解了整个 Block 的数据流,你才能回答"为什么 LLaMA 比 GPT-2 好训练"这种全局问题。


2. LLaMA vs 传统 Transformer:五次精准的"外科手术"

LLaMA 对传统 Transformer(如 GPT-2)做了五点改动,每一点都精准击中一个工程痛点。这不是堆砌 trick,而是用数学上的精巧代替工程上的蛮力

改进点 传统做法(GPT-2) LLaMA 做法 解决的痛点
归一化位置 Post-Norm(子层之后归一化) Pre-Norm(子层之前归一化) 深层网络梯度不稳定
归一化算法 LayerNorm(减均值 + 除方差 + 仿射) RMSNorm(只除 RMS,无均值无偏置) 10-15% 计算冗余
激活函数 ReLU / GELU SwiGLU(门控机制) 负值区域神经元死亡
位置编码 绝对位置编码(正弦波/可学习) RoPE(旋转位置编码) 无法长度外推
注意力机制 MHA(标准多头) GQA(分组查询) KV Cache 显存爆炸

2.1 Pre-Norm vs Post-Norm:归一化放前面还是后面?

这是训练稳定性的分水岭。

Post-Norm(GPT-2):
  x → Attention(x) → LayerNorm → x + residual
       ↑ 先计算,再归一化

Pre-Norm(LLaMA):
  x → RMSNorm(x) → Attention → x + residual
       ↑ 先归一化,再计算

深层网络的梯度困境:Post-Norm 下,第 80 层的梯度要穿过 80 个 Attention + 80 个 LayerNorm 才能回到第 1 层。每个子层都可能放大或缩小梯度——80 层下来,梯度要么爆炸要么消失。

Pre-Norm 把归一化放在子层入口,确保每个子层接收到的输入都是归一化过的。梯度在残差连接上有一条"高速公路"直达浅层。

一句话:Post-Norm 是先做饭再洗手,Pre-Norm 是先洗手再做饭。干净多了,训练也稳定多了。

2.2 SwiGLU 的门控哲学

ReLU 的问题是粗暴:负值区域梯度为 0 → 神经元"死亡" → 再也救不回来。

SwiGLU 用两条通路 + 门控替代了这种一刀切:

SwiGLU(x)=(SiLU(xWgate)⊙xWup)Wdown\text{SwiGLU}(x) = (\text{SiLU}(x W_{\text{gate}}) \odot x W_{\text{up}}) W_{\text{down}}SwiGLU(x)=(SiLU(xWgate)xWup)Wdown

  • gate 通路xWgatex W_{\text{gate}}xWgate → SiLU → 输出 0~1 的"门控信号"
  • up 通路xWupx W_{\text{up}}xWup → 承载实际信息
  • 逐元素乘:门控信号 × 信息 → “这个特征保留 30%,那个特征保留 80%”

传统 FFN 是 ReLU(x·W1)·W2——只有一条通路,ReLU 一刀切。SwiGLU 让模型学会选择性放行


3. 数学定义:Decoder Layer 的完整公式

3.1 两层结构

输入: x  ∈ R^{B×S×d}

Attention Block:
  h = x + GQA(RoPE(RMSNorm(x)))
       ↑     ↑      ↑
       残差   位置编码   Pre-Norm

MLP Block:
  out = h + SwiGLU(RMSNorm(h))
         ↑       ↑       ↑
         残差   门控MLP   Pre-Norm

3.2 为什么 intermediate_size = (8/3) × hidden_size?

这是面试高频题。推导只需要小学算术:

传统 MLPd → 4d → d,参数量 = d×4d+4d×d=8d2d \times 4d + 4d \times d = 8d^2d×4d+4d×d=8d2

SwiGLU MLPgate: d → d_interup: d → d_interdown: d_inter → d,参数量 = 3d×dinter3d \times d_{\text{inter}}3d×dinter

令两者相等:
3d×dinter=8d2⇒dinter=83d≈2.67d3d \times d_{\text{inter}} = 8d^2 \quad \Rightarrow \quad d_{\text{inter}} = \frac{8}{3}d \approx 2.67d3d×dinter=8d2dinter=38d2.67d

具体数值:LLaMA-3 8B 的 hidden_size = 4096 → d_inter = 4096 × 8/3 ≈ 10923 → 实际取 14336(对齐到 128 的倍数,适配 Tensor Core)。

直觉:传统 MLP 用 4× 扩展是因为只有 2 个投影层;SwiGLU 有 3 个投影层,每个自然要窄一些,总参数量才能持平。


4. 代码实现:数据流的骨架

4.1 SwiGLU MLP

class LlamaMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj   = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # gate: [B, S, d] → [B, S, d_inter] → SiLU → 门控信号
        # up:   [B, S, d] → [B, S, d_inter] → 信息通路
        # 逐元素乘: gate ⊙ up → [B, S, d_inter]
        # down: [B, S, d_inter] → [B, S, d]
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

bias=False 的原因:在大规模模型中,偏置项参数量占比极小(d 维偏置 vs d×d_inter 维权重),但去除偏置能减少显存碎片、简化 kernel launch。LLaMA 全系列都无偏置——极简主义的又一体现。

维度追踪(LLaMA-3 8B):

输入 x:                 [B, S, 4096]
gate_proj(x):           [B, S, 14336]    (4096 → 14336)
up_proj(x):             [B, S, 14336]
F.silu(gate) * up:      [B, S, 14336]    逐元素乘,维度不变
down_proj(结果):         [B, S, 4096]     (14336 → 4096)

4.2 LLaMA Decoder Layer:残差连接的四步法则

class LlamaDecoderLayer(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.input_layernorm = RMSNorm(hidden_size)           # Pre-Norm for Attention
        self.self_attn = GroupedQueryAttention(...)            # 内含 RoPE + GQA
        self.post_attention_layernorm = RMSNorm(hidden_size)   # Pre-Norm for MLP
        self.mlp = LlamaMLP(hidden_size, intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # ---- Attention Block ----
        residual = hidden_states                              # ① 保存残差
        hidden_states = self.input_layernorm(hidden_states)   # ② Pre-Norm
        hidden_states = self.self_attn(hidden_states)         # ③ GQA + RoPE
        hidden_states = residual + hidden_states              # ④ 残差连接

        # ---- MLP Block ----
        residual = hidden_states                              # ⑤ 保存残差
        hidden_states = self.post_attention_layernorm(hidden_states)  # ⑥ Pre-Norm
        hidden_states = self.mlp(hidden_states)               # ⑦ SwiGLU MLP
        hidden_states = residual + hidden_states              # ⑧ 残差连接

        return hidden_states

必须严格遵守的四步顺序保存残差 → 归一化 → 子层计算 → 加回残差

为什么残差在归一化之前保存? 残差分支提供"未经归一化的原始信号",归一化分支提供"归一化后经过子层处理的信号"。两者相加 = 原始信号 + 变换信号。如果残差也被归一化,"高速公路"就变成普通道路——梯度无法直达浅层。

4.3 命名陷阱:post_attention_layernorm 到底是什么?

属性名 实际含义 位置
input_layernorm Attention 的 Pre-Norm Attention 之前
post_attention_layernorm MLP 的 Pre-Norm Attention 之后、MLP 之前

post_attention_layernorm 名字里带 “post”,但它在 MLP 之前做归一化——仍然是 Pre-Norm 架构。它叫 “post_attention” 是因为它在 Attention 之后。理解这个命名对阅读 HuggingFace 源码很关键。


5. 工业对照:HuggingFace LLaMA 源码

HF 的实现逻辑完全一致,额外加了 KV Cache 管理和 attention_mask 传递:

# HuggingFace transformers/models/llama/modeling_llama.py (简化)
class LlamaDecoderLayer(nn.Module):
    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, ...):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _, past_key_value = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,   # KV Cache 支持
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states  # + past_key_value

LLaMA-3 各规格参数一览

模型 层数 hidden_size intermediate_size Q头数 KV头数 GQA 压缩比
LLaMA-3 8B 32 4096 14336 32 8 4:1
LLaMA-3 70B 80 8192 28672 64 8 8:1
LLaMA-3 405B 126 16384 53248 128 8 16:1

规律:KV 头数固定为 8 → 模型越大,GQA 压缩比越高;intermediate_size ≈ hidden_size × 8/3,对齐到 128 的倍数。


6. 踩坑实录

现象 根因 解决
忘记保存残差 loss 不下降或 → NaN 直接在归一化后的值上操作,梯度路径被掐断 严格遵循"保存→归一化→计算→加回"四步
两个 RMSNorm 共享参数 训练效果差 Attention 和 MLP 前的归一化需要各自学习不同的 γ 分别创建两个 RMSNorm 实例
gate 和 up 顺序搞反 能跑但效果不如正确版本 F.silu(up) * gate 语义错误——gate 走激活、up 不走 F.silu(gate_proj(x)) * up_proj(x)
intermediate_size 用了 4× 参数量比预期大 50% SwiGLU 有 3 个投影层,用 4× 导致参数暴涨 用 8/3 × hidden_size
忘记 .contiguous() RuntimeError: view size is not compatible transpose 后内存不连续 .contiguous().view(...)

7. 延伸思考:LLaMA 架构的"极简主义"设计哲学

回顾这五点改进,每条都是做减法

  • LayerNorm → RMSNorm:去掉均值和偏置——不减精度,只减计算
  • ReLU → SwiGLU:更精细的控制——不是更复杂,是更聪明
  • 绝对编码 → RoPE:不再"查表",而是"算角度"——数学代替存储
  • MHA → GQA:减少 KV 头——用结构换显存
  • Post-Norm → Pre-Norm:把归一化挪到前面——结构更简单,训练更稳定

核心洞察:大模型架构的进化方向不是"堆更多东西",而是用数学上的精巧代替工程上的堆砌。每减掉一个组件,都留下了一个更优雅的方案。

值得深挖的方向

  • Parallel Transformer Block(GPT-J):把 Attention 和 MLP 并行计算,再合并残差——速度更快但精度略降
  • Sandwich Norm:Pre-Norm + Post-Norm 混合体,在某些训练阶段更稳定
  • MoE(07节):把 MLP 层替换为多个专家 + 路由器——Mixtral/DeepSeek 的架构基础
  • FlashAttention + Block(14节):完整的端到端实现

下一篇:[[06 MoE Router]] — 把 MLP 切成 8 个专家,让每个 token 只找最懂行的 2 个干活。

更多推荐