torch_npu 是 CANN 的 PyTorch 适配层。装上它之后,.to("npu:0") 就能把 tensor 放到昇腾NPU上,PyTorch 的标准 API 自动走 CANN 的算子实现。理论上零代码改动,实际上有几个坑。

安装

pip install torch_npu

torch_npu 的版本必须跟 CANN 版本对齐:

CANN 版本 torch_npu 版本 PyTorch 版本
8.0 2.1.x 2.1.x
8.5 2.3.x 2.3.x

版本不匹配不会报安装错误——但运行时某些算子会 fallback 到 CPU,性能断崖式下降。

基本用法

import torch
import torch_npu

# 检查 NPU 可用
print(torch.npu.is_available())  # True

# tensor 放到 NPU
x = torch.randn(2, 3).to("npu:0")

# 模型放到 NPU
model = MyModel().to("npu:0")

# 就这么简单

PyTorch 的所有标准 API 在昇腾NPU上都能用——torch.nn.functional.lineartorch.nn.functional.softmaxtorch.matmul 等等。torch_npu 在底层把这些 API 映射到 CANN 的算子实现。

自动映射机制

torch_npu 用 PyTorch 的 dispatch 机制拦截 API 调用:

# PyTorch 内部流程
F.linear(x, w, b)
  → torch.ops.aten.linear(x, w, b)  # ATen 算子
  → dispatch 到 npu backend
  → torch_npu.npu.linear(x, w, b)   # CANN 实现
  → ops-nn 的 GEMM + BiasAdd

用户不需要显式调用 torch_npu.npu.*。用标准 PyTorch API 就行。

哪些 API 走了融合

不是所有 API 都能自动走融合。torch_npu 只对以下场景做了融合映射:

PyTorch API 映射到的 CANN 算子 条件
F.scaled_dot_product_attention FlashAttention fp16/bf16, head_dim 对齐
F.linear + F.silu linear_activation 连续 tensor
F.layer_norm fused LayerNorm 连续 tensor

其他的都是标准映射(一个 PyTorch API → 一个 CANN 算子,不做跨算子融合)。

想走更多融合,需要用 torch.compile 或 ATB。

torch.compile + torch_npu

model = MyModel().to("npu:0")
model = torch.compile(model, backend="npu")

# GE 接管计算图,自动做算子融合

torch.compile 触发 GE 编译,GE 会做跨算子融合(graph-autofusion)。这是不换框架的前提下获得融合加速的最简单方式。

torch.compile 有个前提:模型代码必须是 FX traceable 的。动态控制流(if/else 跟 tensor 值相关)、Python side effect(全局变量修改)会让 trace 失败。

性能陷阱

陷阱 1:CPU-NPU 频繁搬运

# ❌ 每步都 CPU ↔ NPU
for step in range(100):
    x = torch.randn(32, 4096).to("npu:0")  # CPU → NPU
    out = model(x)
    loss = out.sum().item()  # NPU → CPU

.item().to("cpu") 触发 NPU→CPU 同步,打断异步执行流水线。应该在 NPU 上尽量多算,最后一次性搬回 CPU。

陷阱 2:非连续 tensor

# ❌ transpose 后 tensor 非连续
x = torch.randn(32, 128, 4096, device="npu:0")
x = x.transpose(1, 2)  # 非连续!
out = F.linear(x, w)    # 走非优化路径

# ✅ 加 contiguous
x = x.transpose(1, 2).contiguous()
out = F.linear(x, w)    # 走优化路径

陷阱 3:不支持的 dtype

# ❌ float64 在昇腾NPU上走 CPU
x = torch.randn(32, 4096, dtype=torch.float64, device="npu:0")
# 某些算子不支持 float64,静默 fallback 到 CPU

# ✅ 用 float16 或 bfloat16
x = torch.randn(32, 4096, dtype=torch.float16, device="npu:0")

torch_npu 是昇腾NPU上最轻量的使用方式——装上就能跑,大部分场景不需要改代码。但要跑得快,注意三个点:减少 CPU-NPU 搬运、保持 tensor 连续、用 float16/bfloat16。仓库在这里:

https://atomgit.com/cann/torch_npu

Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐