你写 F.linear(x, w, b),一行 Python。但在昇腾NPU上它要变成:DMA 搬入 → Cube GEMM → DMA 搬出。中间的转换由 GE(Graph Engine)图引擎完成。理解 GE 的工作流程,才能搞清楚为什么有时候算子没走融合路径。

GE 的工作流程

1. 前端:PyTorch FX Graph → GE 计算图
2. 优化:算子融合、常量折叠、死代码消除
3. 编译:计算图 → NPU 可执行文件(om)
4. 执行:om 文件加载到 NPU 运行

步骤 1-3 发生在首次推理时(“编译期”),步骤 4 在每次推理时执行。编译一次,运行多次。

前端:从 PyTorch 到 GE

torch.compile(model, backend="npu") 触发 GE 编译流程:

import torch
import torch_npu

model = MyModel().to("npu:0")
model = torch.compile(model, backend="npu")  # 触发 GE 编译

# 首次推理:编译 + 执行(较慢,30-60s)
out = model(x)

# 后续推理:直接执行 om 文件(快)
out = model(x)

GE 接收的是 PyTorch 的 FX Graph(一种中间表示)。每个 FX Node 对应一个算子调用,GE 把它映射到 CANN 的算子实现。

算子映射

GE 维护一个映射表:PyTorch 算子 → CANN 算子

torch.nn.functional.linear → ops-nn MatMul + BiasAdd
torch.nn.functional.layer_norm → ops-nn LayerNorm
torch.nn.functional.silu → ops-nn SiLU
torch.nn.functional.scaled_dot_product_attention → ops-transformer FlashAttention

如果映射表里没有对应的 CANN 算子,GE 会尝试用 CANN 的基础算子组合实现。如果组合也做不到,就 fallback 到 CPU 执行——性能灾难。

优化 Pass

GE 的优化器运行多个 Pass:

Pass 1:算子融合

Linear → SiLU → [融合] → FusedLinearSiLU

Pass 2:常量折叠

Shape([1, 4096]) → Reshape → [编译时算好] → [1, 32, 128]

Pass 3:死代码消除

x = Linear(x)  ← 输出没被任何下游使用
[删除]

Pass 4:内存规划

为每个中间 tensor 分配 HBM 地址
重叠分配:tensor A 用完后释放的地址给 tensor B 用

编译产物

编译结果是 .om 文件(离线模型),包含:

  • NPU 可执行指令流
  • 常量数据(权重、编译时算好的参数)
  • 内存布局信息
# 查看 om 文件信息
atc --mode=1 --om=model.om --framework=0

om 文件一旦生成,只要输入 shape 不变,可以反复使用。这就是 ATB 的 cache_dir 里存的东西。

动态 Shape 问题

GE 编译时需要知道输入 tensor 的 shape。如果 shape 在运行时才确定,GE 有两种策略:

策略 1:重新编译。 每次遇到新 shape 都编译一份 om 文件。首 token 延迟高(30-60s),但编译后很快。

策略 2:动态 shape 编译。 编译一份支持任意 shape 的 om 文件。首 token 延迟低(0.5s),但运行时每次要重新计算 Tiling 参数,单步推理比静态 shape 慢 5-10%。

ATB 默认用策略 2。如果你的输入 shape 固定(比如所有请求都 pad 到 4096),可以强制用策略 1 获得更好的性能。

from atb import LLM

model = LLM("model_id", device="npu:0",
            dynamic_shape=False)  # 强制静态 shape 编译

调试 GE 编译

# 开启 GE 调试日志
export GE_LOG_LEVEL=0  # DEBUG

# 查看 GE 编译过程中的算子映射和融合
grep "FusionOp" ge_log.txt
grep "AutoFusion" ge_log.txt
grep "UnsupportedOp" ge_log.txt  # 没有映射的算子

如果日志里有 UnsupportedOp,说明某个 PyTorch 算子没有 CANN 实现,会 fallback 到 CPU。这是推理性能断崖式下降的常见原因。


GE 是 PyTorch 代码和昇腾NPU硬件之间的翻译器。理解了它的编译流程,遇到"为什么推理比预期慢"的问题就能快速定位:是算子没映射上?是融合没生效?还是 shape 不固定导致重复编译?仓库在这里:

https://atomgit.com/cann/ATB

Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐