CANN-torch_npu-昇腾NPU上PyTorch代码怎么一行不改就加速
摘要:CANN的torch_npu适配层让PyTorch代码无需修改即可在昇腾NPU上运行。安装需注意版本匹配,基本用法只需将tensor和模型.to("npu:0")。通过PyTorch的dispatch机制自动映射标准API到CANN算子实现,部分API支持自动融合。使用torch.compile可触发GE图优化,但需代码可trace。性能优化要点包括:减少CPU-NPU数据搬运、保持tens
torch_npu 是 CANN 的 PyTorch 适配层。装上它之后,.to("npu:0") 就能把 tensor 放到昇腾NPU上,PyTorch 的标准 API 自动走 CANN 的算子实现。理论上零代码改动,实际上有几个坑。
安装
pip install torch_npu
torch_npu 的版本必须跟 CANN 版本对齐:
| CANN 版本 | torch_npu 版本 | PyTorch 版本 |
|---|---|---|
| 8.0 | 2.1.x | 2.1.x |
| 8.5 | 2.3.x | 2.3.x |
版本不匹配不会报安装错误——但运行时某些算子会 fallback 到 CPU,性能断崖式下降。
基本用法
import torch
import torch_npu
# 检查 NPU 可用
print(torch.npu.is_available()) # True
# tensor 放到 NPU
x = torch.randn(2, 3).to("npu:0")
# 模型放到 NPU
model = MyModel().to("npu:0")
# 就这么简单
PyTorch 的所有标准 API 在昇腾NPU上都能用——torch.nn.functional.linear、torch.nn.functional.softmax、torch.matmul 等等。torch_npu 在底层把这些 API 映射到 CANN 的算子实现。
自动映射机制
torch_npu 用 PyTorch 的 dispatch 机制拦截 API 调用:
# PyTorch 内部流程
F.linear(x, w, b)
→ torch.ops.aten.linear(x, w, b) # ATen 算子
→ dispatch 到 npu backend
→ torch_npu.npu.linear(x, w, b) # CANN 实现
→ ops-nn 的 GEMM + BiasAdd
用户不需要显式调用 torch_npu.npu.*。用标准 PyTorch API 就行。
哪些 API 走了融合
不是所有 API 都能自动走融合。torch_npu 只对以下场景做了融合映射:
| PyTorch API | 映射到的 CANN 算子 | 条件 |
|---|---|---|
| F.scaled_dot_product_attention | FlashAttention | fp16/bf16, head_dim 对齐 |
| F.linear + F.silu | linear_activation | 连续 tensor |
| F.layer_norm | fused LayerNorm | 连续 tensor |
其他的都是标准映射(一个 PyTorch API → 一个 CANN 算子,不做跨算子融合)。
想走更多融合,需要用 torch.compile 或 ATB。
torch.compile + torch_npu
model = MyModel().to("npu:0")
model = torch.compile(model, backend="npu")
# GE 接管计算图,自动做算子融合
torch.compile 触发 GE 编译,GE 会做跨算子融合(graph-autofusion)。这是不换框架的前提下获得融合加速的最简单方式。
但 torch.compile 有个前提:模型代码必须是 FX traceable 的。动态控制流(if/else 跟 tensor 值相关)、Python side effect(全局变量修改)会让 trace 失败。
性能陷阱
陷阱 1:CPU-NPU 频繁搬运
# ❌ 每步都 CPU ↔ NPU
for step in range(100):
x = torch.randn(32, 4096).to("npu:0") # CPU → NPU
out = model(x)
loss = out.sum().item() # NPU → CPU
.item() 和 .to("cpu") 触发 NPU→CPU 同步,打断异步执行流水线。应该在 NPU 上尽量多算,最后一次性搬回 CPU。
陷阱 2:非连续 tensor
# ❌ transpose 后 tensor 非连续
x = torch.randn(32, 128, 4096, device="npu:0")
x = x.transpose(1, 2) # 非连续!
out = F.linear(x, w) # 走非优化路径
# ✅ 加 contiguous
x = x.transpose(1, 2).contiguous()
out = F.linear(x, w) # 走优化路径
陷阱 3:不支持的 dtype
# ❌ float64 在昇腾NPU上走 CPU
x = torch.randn(32, 4096, dtype=torch.float64, device="npu:0")
# 某些算子不支持 float64,静默 fallback 到 CPU
# ✅ 用 float16 或 bfloat16
x = torch.randn(32, 4096, dtype=torch.float16, device="npu:0")
torch_npu 是昇腾NPU上最轻量的使用方式——装上就能跑,大部分场景不需要改代码。但要跑得快,注意三个点:减少 CPU-NPU 搬运、保持 tensor 连续、用 float16/bfloat16。仓库在这里:
https://atomgit.com/cann/torch_npu
更多推荐


所有评论(0)