自定义ROCm算子

在这里插入图片描述

实际上我们可以自定义一个ROCm算子,并在torch里调用它,这样可以用py的语法,ROCm的性能。这也是为什么说py是胶水语言,实际上Torch的大部分算子都是这样实现的,python只起到一个顶层接口的作用

本节代码位于AMD云服务器的src/infra/custom-pytorch-operator/code/custom_swish/路径下,或者也可以看这个仓库https://github.com/datawhalechina/hello-rocm

从torch到rocm大致的执行流程是这样:
在这里插入图片描述
torch调用封装函数,函数由pybind封装,去调用cpp的ATen库,这是一个cpp的通用张量库,支持我们把自定义算子注册到这个库里。再往下就是HIP实现的自定义算子,给算子外面套一层ATen的封装,就能注册到ATen。最后HIP算子编译成AMD机器码执行。

所以我们想实现一个python可调用的函数,需要实现HIP内核,实现ATen注册,实现pybind封装,完成python注册。

HIP实现

这里实现一个fused_swish算子

swish的定义是f(x)=x∗sigmoid(x)f(x)=x*sigmoid(x)f(x)=xsigmoid(x),这是个简单的逐元素操作算子。

fused的意思是融合算子,如果不融合,一般的做法是,把乘法和sigmoid分别调用两个算子进行计算,开始先sigmoid,存到一个中间张量,再读这个中间张量计算乘法。这里的保存中间结果和读取中间结果其实可以省去。直接定义一个算子,内部完成swish的全流程,这减少了内存读写,能大幅提高性能。

分析一下具体实现:

  • fused_swish_forward_kernel(const scalar_t* input, scalar_t* output, int size)对标torch算子,需要有前向传播和反向传播,这里先定义前向传播
  • template <typename scalar_t>支持模版泛型,只用写一遍就能编译时动态生成任意类型的接口
  • for (int i = idx; i < size; i += stride)跨步循环。一个block启动的线程是有限的,但是输入可能很大,不可能令线程块大小=输入大小,考虑让一个线程负责多个元素的计算,也就是这里定义一个步长stride,线程i负责所有模stride同余的位置的计算
  • float sigmoid_x = 1.0f / (1.0f + expf(-x));sigmoid中间结果,只保存在一个临时变量里,不存到内存中,这是融合的关键部分
  • fused_swish_backward_kernel(const scalar_t* grad_output, const scalar_t* x, scalar_t* grad_x, int size)反向传播,需要给出输入张量x,输出梯度的缓冲区grad_x,还有由于梯度是链式法则,还要传入上一级的梯度grad_output
  • void launch_fused_swish_forward(const scalar_t* input, scalar_t* output, int size)host侧函数,封装核函数,提供给ATen调用
  • hipLaunchKernelGGL(fused_swish_forward_kernel<scalar_t>, dim3(blocks), dim3(threads), 0, 0, input, output, size);动态计算启动的block个数,每个block固定256线程。block数也有上限256
  • template void launch_fused_swish_forward<float>(const float*, float*, int);模板类只有被调用才会编译,这里虽然是库函数,没有调用的地方,但为了编译出我们想要的数据类型的接口,手动实例化一下。注意这里float和double都要实例化,ATen库默认你有这两个类型。
#include <hip/hip_runtime.h>
#include <math.h>

// 1. 前向传播 Kernel:支持 Grid-Stride Loop 和 模板泛型
template <typename scalar_t>
__global__ void fused_swish_forward_kernel(const scalar_t* input, scalar_t* output, int size) {
    // 计算当前线程的全局索引
    int idx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
    // 计算跨步大小:整个网格的线程总数
    int stride = hipBlockDim_x * hipGridDim_x;

    // 网格跨步循环,处理大于线程总数的数据
    for (int i = idx; i < size; i += stride) {
        // 强制转换为 float 进行中间计算,保证精度
        float x = static_cast<float>(input[i]);
        float sigmoid_x = 1.0f / (1.0f + expf(-x));
        // 计算 Swish: x * sigmoid(x),并转回原类型写回
        output[i] = static_cast<scalar_t>(x * sigmoid_x);
    }
}

// 2. 反向传播 Kernel
// Swish 的导数推导: f'(x) = f(x) + sigmoid(x) * (1 - f(x))
template <typename scalar_t>
__global__ void fused_swish_backward_kernel(const scalar_t* grad_output, const scalar_t* x, scalar_t* grad_x, int size) {
    int idx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
    int stride = hipBlockDim_x * hipGridDim_x;

    for (int i = idx; i < size; i += stride) {
        float val_x = static_cast<float>(x[i]);
        float go = static_cast<float>(grad_output[i]);

        float sigmoid_x = 1.0f / (1.0f + expf(-val_x));
        float swish_x = val_x * sigmoid_x;

        // 根据链式法则计算当前元素的梯度:grad_output * f'(x)
        float grad_val = go * (swish_x + sigmoid_x * (1.0f - swish_x));
        grad_x[i] = static_cast<scalar_t>(grad_val);
    }
}

// 3. 供 C++ Wrapper 调用的 Host 端启动函数
template <typename scalar_t>
void launch_fused_swish_forward(const scalar_t* input, scalar_t* output, int size) {
    int threads = 256;
    // 限制最多启动 256 个 Block,利用跨步循环处理超大数据,避免调度过载
    int blocks = min((size + threads - 1) / threads, 256);
    hipLaunchKernelGGL(fused_swish_forward_kernel<scalar_t>, dim3(blocks), dim3(threads), 0, 0, input, output, size);
}

template <typename scalar_t>
void launch_fused_swish_backward(const scalar_t* grad_output, const scalar_t* x, scalar_t* grad_x, int size) {
    int threads = 256;
    int blocks = min((size + threads - 1) / threads, 256);
    hipLaunchKernelGGL(fused_swish_backward_kernel<scalar_t>, dim3(blocks), dim3(threads), 0, 0, grad_output, x, grad_x, size);
}

// 4. 显式实例化模板(告诉编译器我们需要编译哪些数据类型的版本)
template void launch_fused_swish_forward<float>(const float*, float*, int);
template void launch_fused_swish_backward<float>(const float*, const float*, float*, int);
template void launch_fused_swish_forward<double>(const double*, double*, int);
template void launch_fused_swish_backward<double>(const double*, const double*, double*, int);

ATen注册

  • #define CHECK_HIP(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a HIP/CUDA tensor")防御性编程,检查张量是否在GPU上
  • #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")类似地检查张量的内存是否连续。这是因为我们后面的HIP算子都是默认张量内存连续的,也就是不管什么shape的张量,在内存里实际上都是一个一维数组,然后我们对这个一维数组手动寻址,确定每不同下标的值。但是torch层张量支持切片,reshape,转置之类的操作,这些操作可能会导致张量内存不连续,在torch层这是允许的,因为torch.tensor保存了reshape信息,会把代码中的下标映射到实际下标,但这些reshape信息不会传递给cpp,cpp默认是连续的。所以在调用自定义算子前必须保证张量是连续的,这在py里只要tensor.continguou()即可,但以防py代码没写这一行,我们在cpp这里提前检查一下
  • AT_DISPATCH_FLOATING_TYPES,ATen提供的宏,相当于帮我们注册了float和double两个类型的ATen函数,名称是fused_swish_forward,实际内核是launch_fused_swish_forward,这是我们在HIP文件里封装的Host侧接口,input.data_ptr<scalar_t>()对参数做了转换,ATen层的参数都是torch::tensor类型,手动转成指针类型
  • PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)最后pybind绑定到py接口,传入的函数指针是上面定义的ATen接口&fused_swish_forward
#include <torch/extension.h>

// 声明外部 HIP 文件中定义的模板 Launch 函数
template <typename scalar_t>
void launch_fused_swish_forward(const scalar_t* input, scalar_t* output, int size);
template <typename scalar_t>
void launch_fused_swish_backward(const scalar_t* grad_output, const scalar_t* x, scalar_t* grad_x, int size);

// --- 防御性检查宏 ---
// 检查 Tensor 是否在 GPU 上
#define CHECK_HIP(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a HIP/CUDA tensor")
// 检查 Tensor 内存是否连续
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// 组合检查
#define CHECK_INPUT(x) CHECK_HIP(x); CHECK_CONTIGUOUS(x)

// --- 前向传播 C++ 接口 ---
torch::Tensor fused_swish_forward(torch::Tensor input) {
    CHECK_INPUT(input);
    // 预先分配好显存存放结果,形状、类型和设备与 input 保持一致
    auto output = torch::empty_like(input);

    // 动态分发宏:根据 input 的实际 scalar_type(),自动实例化并调用对应的 C++ 模板函数
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "fused_swish_forward", ([&] {
        launch_fused_swish_forward<scalar_t>(
            input.data_ptr<scalar_t>(), // 获取底层显存指针
            output.data_ptr<scalar_t>(),
            input.numel() // 获取元素总数
        );
    }));
    return output;
}

// --- 反向传播 C++ 接口 ---
torch::Tensor fused_swish_backward(torch::Tensor grad_output, torch::Tensor x) {
    CHECK_INPUT(grad_output);
    CHECK_INPUT(x);
    // 分配用于存储 x 梯度的显存
    auto grad_x = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "fused_swish_backward", ([&] {
        launch_fused_swish_backward<scalar_t>(
            grad_output.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            grad_x.data_ptr<scalar_t>(),
            x.numel()
        );
    }));
    return grad_x;
}

// 使用 Pybind11 将 C++ 函数暴露给 Python
// TORCH_EXTENSION_NAME 是编译时自动生成的模块名
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fused_swish_forward, "Fused Swish Forward (HIP)");
    m.def("backward", &fused_swish_backward, "Fused Swish Backward (HIP)");
}

py注册脚本

最后调用torch的注册工具,把前面的ATen封装文件和HIP内核作为输入,进行编译,编译出二进制绑定到py函数,执行python setup.py install即可绑定到当前环境下的py解释器,然后就可以直接调用了

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='my_custom_swish', # 安装后的包名
    ext_modules=[
        CUDAExtension(
            name='my_custom_swish_backend', # 编译生成的底层库名
            sources=['fused_swish_wrapper.cpp', 'fused_swish_kernel.hip'],
            # 开启 C++ 和 HIP 编译器的最高级别优化 -O3
            # 在 ROCm 环境下,'nvcc' 参数会被传递给 hipcc
            extra_compile_args={'cxx': ['-O3'], 'nvcc':['-O3']}
        )
    ],
    cmdclass={'build_ext': BuildExtension}
)

torch封装和测试

上面我们注册了一个py函数,但是这个函数还不能被torch的Autograd 识别并自动计算梯度,虽然我们实现了反向传播逻辑。需要套一层torch库的逻辑

  • class FusedSwishFunction(torch.autograd.Function)继承自Autograd 类,内部实现torch标准的forward,backward封装
  • ctx.save_for_backward(x)ctx是一个临时缓冲区,把输入x保存下来,后面反向传播要用
  • x, = ctx.saved_tensors反向传播时取出刚刚保存的输入张量
  • if ctx.needs_input_grad[0]:如果需要梯度
  • grad_x = my_custom_swish_backend.backward(grad_output.contiguous(), x)这里才调用我们刚才封装的自定义函数,计算梯度
  • out = fused_swish(x)验证前向传播
  • loss要是一个标量,才能有loss对各个输入的标量梯度loss = out.sum(),如果loss是一个张量,求梯度得到的是一个矩阵
  • loss.backward()自动反向传播,会调用我们上面封装的Autograd 类的backward接口
import torch
# 导入我们刚才编译安装好的底层 C++ 库
import my_custom_swish_backend

class FusedSwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        """
        前向传播逻辑
        ctx: 上下文对象,用于存储反向传播需要的信息
        x: 输入 Tensor
        """
        # 1. 调用底层 C++ 前向函数
        result = my_custom_swish_backend.forward(x)
        # 2. 将输入 x 存入上下文(Context),留给反向求导时使用
        # 因为 Swish 的导数计算需要用到原始输入 x
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """
        反向传播逻辑
        ctx: 上下文对象,取回前向存储的信息
        grad_output: 上游传来的梯度
        """
        # 1. 取出前向时保存的 x
        x, = ctx.saved_tensors

        # 2. 检查链式法则上游是否需要梯度 (工业级优化)
        grad_x = None
        if ctx.needs_input_grad[0]:
            # 3. 调用底层 C++ 反向函数
            # 注意: 反向传播传进来的 grad_output 可能因为经过各种切片操作导致在内存中不再连续
            # 所以调用 .contiguous() 是必不可少的防御手段!
            grad_x = my_custom_swish_backend.backward(grad_output.contiguous(), x)

        # 返回输入 x 的梯度
        return grad_x

# 封装成一个优雅的 Python 函数供深度学习模型使用
def fused_swish(x):
    return FusedSwishFunction.apply(x)

# ======== 验证求导链路是否畅通 ========
print("--- 功能与精度验证 ---")
# 创建一个需要梯度的 Tensor
x = torch.randn(5, device='cuda', dtype=torch.float32, requires_grad=True)

print("输入 x:", x)

# 前向传播
out = fused_swish(x)
print("Swish 输出:", out)

# 模拟算出一个标量 Loss
loss = out.sum()

# 一键反向传播!PyTorch 会自动调用我们定义的 backward 方法
loss.backward()

print("自动求导后的梯度 (x.grad):", x.grad)

# 验证:Swish 在 x=0 处的导数应该是 0.5
x_zero = torch.tensor([0.0], device='cuda', requires_grad=True)
fused_swish(x_zero).backward()
print("x=0 处的导数 (预期 0.5):", x_zero.grad.item())

print("Autograd 反向传播打通!自定义算子现在具有学习能力了!")

测试性能

刚才是测试能否调用,这里测试性能

import time

# 准备 5000 万个元素的大张量 (200MB 显存)
size = 50000000
# native 作为参照组
x_native = torch.randn(size, device='cuda', requires_grad=True)
# custom 作为自定义算子测试组,克隆一份独立的数据
x_custom = x_native.clone().detach().requires_grad_(True)

print(f"\n开始 Benchmark,数据大小: {size} 元素...")

# 预热 GPU (防止第一次初始化和 JIT 编译开销影响计时)
for _ in range(10):
    (x_native * torch.sigmoid(x_native)).sum().backward()
    fused_swish(x_custom).sum().backward()

# --- 测试 1: 原生 PyTorch 性能 ---
torch.cuda.synchronize() # 确保 GPU 空闲
start = time.time()
for _ in range(50):
    out = x_native * torch.sigmoid(x_native) # Forward 启动至少 2 个 Kernel
    out.sum().backward()                     # Backward 启动数个 Kernel
torch.cuda.synchronize() # 等待所有任务完成
torch_time = (time.time() - start) / 50 * 1000 # 计算平均耗时 (ms)

# --- 测试 2: 自定义 Fused C++ 算子 ---
torch.cuda.synchronize()
start = time.time()
for _ in range(50):
    out = fused_swish(x_custom)  # Forward 仅启动 1 个 Kernel
    out.sum().backward()         # Backward 仅启动 1 个 Kernel
torch.cuda.synchronize()
custom_time = (time.time() - start) / 50 * 1000

print(f"\n--- 极限性能 Benchmark (5000万元素, 50轮平均) ---")
print(f"原生 PyTorch (Forward + Backward) 耗时: {torch_time:.2f} ms")
print(f"自定义 Fused 算子 (纯 C++) 耗时: {custom_time:.2f} ms")
print(f"综合性能提升: {torch_time / custom_time:.2f} 倍!")

加速比基本1.7x左右,计算操作上我们其实并没有优化,element-wise基本都没法优化计算顺序,逐个计算就完了。这类算子的优化点主要就在融合,减少显存读写次数

开始 Benchmark,数据大小: 50000000 元素...

--- 极限性能 Benchmark (5000万元素, 50轮平均) ---
原生 PyTorch (Forward + Backward) 耗时: 5.25 ms
自定义 Fused 算子 (纯 C++) 耗时: 3.07 ms
综合性能提升: 1.71 倍!
Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐