环境准备:避开 ROCm 7.x 的“依赖坑”

在 AMD Instinct MI300X 上跑大模型微调,最让人头疼的往往不是算法本身,而是环境配置。很多兄弟在第一步就卡住了:驱动装上了,但 PyTorch 识别不到显卡,或者一跑 DeepSpeed 就报段错误。其实,只要理清 ROCm 7.x 的依赖链条,这事儿并没想象中那么复杂。

首先,操作系统建议锁定在 Ubuntu 22.04 LTS。ROCm 对新内核的支持虽然一直在进步,但 22.04 目前的生态兼容性是最稳的。安装完官方 ROCm 7.x 驱动后,别急着装 Python 包,先用 rocm-smirocminfo 确认显卡状态。看到显卡温度、功耗正常,且架构标识(如 gfx942)正确显示,才算地基打牢了。

接下来是重头戏:PyTorch 的源码编译。虽然官方提供了预编译包,但在 MI300X 这种新架构上,为了彻底激活 FlashAttention 和 DeepSpeed 的性能,强烈建议从源码编译。这里有个关键环境变量必须设对:PYTORCH_ROCM_ARCH=gfx942。如果漏了这一句,编译出来的二进制文件在运行时可能会报 "illegal instruction",到时候排查起来能让人怀疑人生。同时,确保你的 GCC 版本在 11 左右,过高或过低都可能导致 HIP 编译器链接失败。

核心配置:DeepSpeed ZeRO-3 与大显存的正确打开方式

MI300X 最大的卖点就是那恐怖的 192GB HBM3 显存。对于算法工程师来说,这意味着我们终于可以在单机上从容地微调 70B 甚至更大参数的模型,而不用被迫去搞复杂的多机分布式。要实现这一点,DeepSpeed ZeRO-3 策略是必选项。

在 LLaMA-Factory 的配置文件中,我们需要明确开启 ZeRO-3 优化。这不仅仅是改个参数那么简单,它涉及到显存分片、梯度卸载等一系列操作。针对 MI300X,建议在 deepspeed_config.json 中做如下关键设置:

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "none", 
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_parition_grads": true,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "bf16": {
    "enabled": true
  }
}

注意这里的 offload_optimizer 设置为 none。因为在 MI300X 上,显存足够大,完全可以把优化器状态留在 GPU 显存里,没必要卸货到 CPU 内存从而拖累通信带宽。只有当你要微调超大规模模型(如 405B)且单卡显存吃紧时,才考虑开启 CPU Offload。

通过 ZeRO-3,模型参数、梯度和优化器状态会被切分存储在所有可用的 GPU 上。在单机八卡环境下,这相当于让你拥有了近 1.5TB 的聚合显存空间。实测中,这套配置能让 70B 模型在 batch size 较大的情况下依然稳定运行,不再频繁遭遇 OOM(显存溢出)报错。

精度与加速:BF16 与 FlashAttention 的实战表现

在精度选择上,BF16 (BFloat16) 是 ROCm 7.x 环境下的首选。相比 FP16,BF16 拥有更大的动态范围,能有效避免微调过程中的梯度下溢问题,特别是在训练后期损失函数收敛阶段,数值稳定性表现更好。在 LLaMA-Factory 中,只需将 compute_type 设为 bf16,框架会自动处理混合精度训练的梯度缩放逻辑。

另一个性能杀手锏是 FlashAttention 的 ROCm 适配。在旧版本中,AMD 平台对 FlashAttention 的支持一直是个痛点,但 ROCm 7.x 配合最新的 Triton 编译器,已经实现了原生级的高效支持。

我在实际测试中对比了开启与关闭 FlashAttention 的效果(基于 Llama-3-70B 模型,序列长度 4096):

  • 未开启 FlashAttention:训练速度约为 145 tokens/s,显存占用较高,长序列下容易触发重计算机制导致算力波动。
  • 开启 FlashAttention:训练速度飙升至 210+ tokens/s,提升幅度接近 45%。更重要的是,显存占用显著下降,使得我们可以进一步增大 batch size 或上下文窗口长度。

这一提升主要得益于 FlashAttention 减少了 HBM 的读写次数,让 MI300X 的高带宽特性真正转化为了算力。在 LLaMA-Factory 启动时,确保传入 --flash_attn 参数,并确认底层算子已正确编译链接。

避坑指南与微调启动

最后分享两个实战中容易踩的坑。第一,多卡通信问题。在单机多卡微调时,如果发现训练速度远低于预期,检查是否启用了 RCCL(ROCm 版的 NCCL)。LLaMA-Factory 默认会尝试调用,但有时需要手动指定 NCCL_DEBUG=INFO 来查看通信链路是否正常建立。确保所有卡在同一个 PCIe 根复合体或通过 Infinity Fabric 互联,避免数据走低速通道。

第二,梯度裁剪与爆炸。在使用 BF16 进行大模型微调时,偶尔会遇到梯度范数过大的情况。建议在配置中适当调整 max_grad_norm,通常设置在 1.0 到 0.5 之间,防止训练发散。

当你把上述配置都理顺后,启动命令就非常简洁了:

llamafactory-cli train \
    --model_name_or_path meta-llama/Meta-Llama-3-70B \
    --dataset alpaca_en_demo \
    --template llama3 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --deepspeed ds_config_zero3.json \
    --bf16 true \
    --flash_attn true \
    --output_dir ./saves/llama3-70b-lora

看着终端里 Loss 稳步下降,GPU 利用率维持在 90% 以上,那种成就感是无可替代的。AMD 的性价比优势在这一刻体现得淋漓尽致,让我们能以更低的成本探索更大参数的模型边界。

如果你也想亲手试试在 MI300X 上微调 70B 大模型,却苦于没有合适的硬件环境,现在机会来了。200 小时 GPU 算力已就位,快来领取:https://marketing.csdn.net/questions/Q2604140858304426315?utm_source=AIpaper

Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐