自定义算子

在这里插入图片描述
之前这篇文章

https://blog.csdn.net/Maxwell_Newton/article/details/162127777?spm=1011.2415.3001.5331

学习了一下自定义ROCm算子接入python的流程,想着自己实现一个复杂一点的算子,跑通一遍这个流程

  • HIP实现内核
  • ATen注册
  • pybind封装
  • python调用,和torch对比测试

完整项目发布于仓库https://github.com/hereisaway/ROCm-online-softmax

算子定义

Softmax 算子(Softmax Operator)是深度学习和机器学习中最常用的激活函数之一,主要用于将一个包含任意实数的向量(通常称为 Logits)映射为一个概率分布。以下是 Softmax 算子的标准定义、数学公式以及其核心特性:

1. 数学定义

对于一个包含 NNN 个元素的输入向量 x=[x1,x2,…,xN]T\mathbf{x} = [x_1, x_2, \dots, x_N]^Tx=[x1,x2,,xN]T,Softmax 算子会将其映射为相同维度的输出向量 s=[s1,s2,…,sN]T\mathbf{s} = [s_1, s_2, \dots, s_N]^Ts=[s1,s2,,sN]T。对于输出向量中的第 iii 个元素 sis_isi,其计算公式如下:si=Softmax(xi)=exi∑j=1Nexjs_i = \text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}si=Softmax(xi)=j=1Nexjexi关键特性归一化(Probability Distribution):所有输出元素的值都在 (0,1)(0, 1)(0,1) 区间内,且它们的总和严格等于 1:∑i=1Nsi=1\sum_{i=1}^{N} s_i = 1i=1Nsi=1放大差异(“Soft” Max):它被称为“软最大值”,是因为它不仅会挑选出最大的那个数,还会通过指数函数(exie^{x_i}exi)拉大输入之间的差距。相比于 Hardmax(直接将最大值设为 1,其余设为 0),Softmax 保持了函数的可导性,适合梯度下降优化。

2. 硬件与工程实现中的实际工程定义

在 GPU 计算(如 CUDA 内核开发)或高性能算子优化中,直接使用上述标准公式进行计算会遭遇数值溢出(Numerical Overflow)的问题。因为当 xix_ixi 较大时(例如 xi=1000x_i = 1000xi=1000),e1000e^{1000}e1000 会直接导致浮点数溢出(inf)。因此,工业界(如 cuDNN, PyTorch, TensorRT)在实现 Softmax 算子时,通常使用的是平移不变性(Translation Invariance)修正后的安全版本。安全的 Softmax 计算公式令 M=max⁡(x1,x2,…,xN)M = \max(x_1, x_2, \dots, x_N)M=max(x1,x2,,xN),则有:si=exi−M∑j=1Nexj−Ms_i = \frac{e^{x_i - M}}{\sum_{j=1}^{N} e^{x_j - M}}si=j=1NexjMexiM通过减去最大值 MMM,输入向量的最大值被限制为 0(e0=1e^0 = 1e0=1),其余值全部 ≤0\le 00,这保证了指数项 exi−Me^{x_i - M}exiM 的取值范围严格在 (0,1](0, 1](0,1] 之间,彻底避免了上溢出(Overflow)问题(即使遭遇下溢出变为 0,也是安全的)。

实现思路

1. 朴素 Softmax (Naive Softmax)

为了保证数值安全性,工业界常用的“朴素”安全 Softmax 包含 3 个独立的循环步骤(即 3-pass 算法)。对于一个长度为 NNN 的向量 XXX

  • Pass 1 (求最大值):找出最大值 M=max⁡j(xj)M = \max_j(x_j)M=maxj(xj)
  • Pass 2 (求指数和):计算每个元素的指数并累加,D=∑jexj−MD = \sum_j e^{x_j - M}D=jexjM
  • Pass 3 (归一化):计算最终结果 si=exi−MDs_i = \frac{e^{x_i - M}}{D}si=DexiM

硬件视角的致命缺陷:多轮访存在 GPU 编程中,如果向量 NNN 的规模很大(或者处于 Transformer 的 Attention 矩阵计算中),这个 3-pass3\text{-pass}3-pass 的逻辑会带来巨大的显存(HBM/DRAM)带宽浪费:

  • 第一次访存:从 Global Memory 读取 XXX,计算出 MMM,写回或暂存。
  • 第二次访存:再次从 Global Memory 读取 XXX,计算出 DDD
  • 第三次访存:第三次从 Global Memory 读取 XXX,计算出 sis_isi,并将结果写回 Global Memory。

这意味着数据被反复读取了 3 次(3 次 Read,1 次 Write)。由于 Softmax 属于 Element-wise(访存密集型) 算子,计算核心(ALU)大部分时间都在等待数据从显存传输,算子性能完全受限于显存带宽(Memory-bound)。

2. Online Softmax

为了减少访存次数,FlashAttention 的核心前置技术——Online Softmax(由 Nvidia 的 Max Milakov 等人提出)应运而生。它的核心思想是:将“求最大值”和“求分母累加和”合二为一,在单次遍历(1-pass)中动态更新。

数学递推原理:假设我们已经处理了前 kkk 个元素,当前已知的最大值为 MkM_kMk,当前分母的指数累加和为 DkD_kDk。当处理第 k+1k+1k+1 个元素 xk+1x_{k+1}xk+1 时,我们需要动态更新这两个局部变量:

  • 更新局部最大值 Mk+1M_{k+1}Mk+1Mk+1=max⁡(Mk,xk+1)M_{k+1} = \max(M_k, x_{k+1})Mk+1=max(Mk,xk+1)
  • 动态修正并更新累加和 Dk+1D_{k+1}Dk+1:如果新的元素比之前的最大值还要大(xk+1>Mkx_{k+1} > M_kxk+1>Mk),那么之前基于 MkM_kMk 计算的指数和 DkD_kDk 就缩放错了比例。我们需要利用指数的性质对其进行重缩放(Rescaling)修正:Dk+1=Dk⋅eMk−Mk+1+exk+1−Mk+1D_{k+1} = D_k \cdot e^{M_k - M_{k+1}} + e^{x_{k+1} - M_{k+1}}Dk+1=DkeMkMk+1+exk+1Mk+1
  • 物理含义:eMk−Mk+1e^{M_k - M_{k+1}}eMkMk+1 恰好是将旧的累加和从以 MkM_kMk 为基准,对齐到以新的 Mk+1M_{k+1}Mk+1 为基准。

算法执行流程 (2-pass):通过这个递推公式,Online Softmax 将算法精简为了 2-pass:

  • Pass 1 (Online 迭代):只需遍历一遍数据,就能同时得到全局最大值 MNM_NMN 和全局正确的指数分母和 DND_NDN
  • Pass 2 (归一化):再遍历一遍数据,直接输出最终的 si=exi−MNDNs_i = \frac{e^{x_i - M_N}}{D_N}si=DNexiMN

实现流程

HIP 朴素softmax内核

简单的3pass,也就是三次循环。分块思路是,输入的二维矩阵,每一行分给一个线程块。三次循环前两次都是reduce操作,一个max一个sum,用前一篇reduce算子的思路,warp reduce规约。第三次循环就是一个element-wise 除法,跨步循环即可。

  • const float* row_x = x + static_cast<int64_t>(row) * cols,当前block处理输入张量第第几行,找到这一行的起始地址
  • float* row_y = y + static_cast<int64_t>(row) * cols;同样找到在输出张量的起始地址
  • for (int col = tid; col < cols; col += blockDim.x)第一个循环,跨步循环,实现任意长度的规约,每个block只有blockdim.x个线程,所以blockdim.x为步长
  • local_max = warp_reduce_max(local_max);跨步循环之后warp reduce求出每个warp的结果
  • warp_buf[warp_id] = local_sum;每个warp第一个线程结果写入线程束对应的缓冲区
  • row_m = lane < num_warps ? warp_buf[lane] : -INFINITY;第一个warp对每个warp的结果再次进行规约
  • warp_buf[0] = row_m;至此得到了每一行的最大指数
  • 第二个循环,仍然是规约,只不过这次规约的是sumlocal_sum += expf(row_x[col] - row_m);,local_sum = warp_reduce_sum(local_sum);别的都和第一轮差不多
  • row_y[col] = expf(row_x[col] - row_m) / row_d;最后一次循环用sum和max做缩放,计算每个位置的结果
__global__ void naive_3pass_softmax_kernel(const float* __restrict__ x,
                                           float* __restrict__ y,
                                           int rows,
                                           int cols) {
    int row = blockIdx.x;
    if (row >= rows) {
        return;
    }

    int tid = threadIdx.x;
    int lane = tid % warpSize;
    int warp_id = tid / warpSize;
    int num_warps = (blockDim.x + warpSize - 1) / warpSize;

    const float* row_x = x + static_cast<int64_t>(row) * cols;
    float* row_y = y + static_cast<int64_t>(row) * cols;

    __shared__ float warp_buf[kMaxWarpsPerBlock];

    float local_max = -INFINITY;
    for (int col = tid; col < cols; col += blockDim.x) {
        local_max = fmaxf(local_max, row_x[col]);
    }
    local_max = warp_reduce_max(local_max);
    if (lane == 0) {
        warp_buf[warp_id] = local_max;
    }
    __syncthreads();

    float row_m = -INFINITY;
    if (warp_id == 0) {
        row_m = lane < num_warps ? warp_buf[lane] : -INFINITY;
        row_m = warp_reduce_max(row_m);
        if (lane == 0) {
            warp_buf[0] = row_m;
        }
    }
    __syncthreads();
    row_m = warp_buf[0];

    float local_sum = 0.0f;
    for (int col = tid; col < cols; col += blockDim.x) {
        local_sum += expf(row_x[col] - row_m);
    }
    local_sum = warp_reduce_sum(local_sum);
    if (lane == 0) {
        warp_buf[warp_id] = local_sum;
    }
    __syncthreads();

    float row_d = 0.0f;
    if (warp_id == 0) {
        row_d = lane < num_warps ? warp_buf[lane] : 0.0f;
        row_d = warp_reduce_sum(row_d);
        if (lane == 0) {
            warp_buf[0] = row_d;
        }
    }
    __syncthreads();
    row_d = warp_buf[0];

    for (int col = tid; col < cols; col += blockDim.x) {
        row_y[col] = expf(row_x[col] - row_m) / row_d;
    }
}

HIP online softmax内核

因为我们要在第一次循环就做完max和sum的规约,仍然考虑使用warp reduce规约框架,但是需要定义一个类,包含max和sum信息,然后定义这个类的合并规则,然后把原本的warp reduce里的max,+操作改成这个合并操作即可

  • local = online_update(local, row_x[col]);初始跨步循环,维护一个OnlinePair ,每次加入一个新元素,更新最大值和sum
  • local = warp_reduce_online_pair(local);仍然线程束规约,只是规约的是OnlinePair ,内部调用v = online_combine(v, other);处理两个线程的合并,注意和online_update不同,这里合并的是两个OnlinePair 对象,而online_update是一个OnlinePair 和一个普通元素x
  • online_combine内,先确定最大值,这是好做的,然后要把最大值更小的一侧缩放,用ifelse不太好,会造成线程束发散,我们直接用这个通用公式out.d = a.d * expf(a.m - m) + b.d * expf(b.m - m);代价是会两次expf调用
struct OnlinePair {
    float m;
    float d;
};

__device__ __forceinline__ OnlinePair online_update(OnlinePair acc, float x) {
    float new_m = fmaxf(acc.m, x);
    acc.d = acc.d * expf(acc.m - new_m) + expf(x - new_m);
    acc.m = new_m;
    return acc;
}

__device__ __forceinline__ OnlinePair online_combine(OnlinePair a, OnlinePair b) {
    if (a.d == 0.0f) {
        return b;
    }
    if (b.d == 0.0f) {
        return a;
    }
    float m = fmaxf(a.m, b.m);
    OnlinePair out;
    out.m = m;
    out.d = a.d * expf(a.m - m) + b.d * expf(b.m - m);
    return out;
}

__device__ __forceinline__ OnlinePair warp_reduce_online_pair(OnlinePair v) {
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        OnlinePair other;
        other.m = __shfl_down(v.m, offset);
        other.d = __shfl_down(v.d, offset);
        v = online_combine(v, other);
    }
    return v;
}

__device__ __forceinline__ float warp_reduce_max(float v) {
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        v = fmaxf(v, __shfl_down(v, offset));
    }
    return v;
}

__device__ __forceinline__ float warp_reduce_sum(float v) {
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        v += __shfl_down(v, offset);
    }
    return v;
}

__global__ void online_softmax_kernel(const float* __restrict__ x,
                                      float* __restrict__ y,
                                      int rows,
                                      int cols) {
    int row = blockIdx.x;
    if (row >= rows) {
        return;
    }

    int tid = threadIdx.x;
    int lane = tid % warpSize;
    int warp_id = tid / warpSize;
    int num_warps = (blockDim.x + warpSize - 1) / warpSize;

    const float* row_x = x + static_cast<int64_t>(row) * cols;
    float* row_y = y + static_cast<int64_t>(row) * cols;

    OnlinePair local{-INFINITY, 0.0f};
    for (int col = tid; col < cols; col += blockDim.x) {
        local = online_update(local, row_x[col]);
    }

    local = warp_reduce_online_pair(local);

    __shared__ float warp_m[kMaxWarpsPerBlock];
    __shared__ float warp_d[kMaxWarpsPerBlock];
    if (lane == 0) {
        warp_m[warp_id] = local.m;
        warp_d[warp_id] = local.d;
    }
    __syncthreads();

    OnlinePair block_pair{-INFINITY, 0.0f};
    if (warp_id == 0) {
        if (lane < num_warps) {
            block_pair.m = warp_m[lane];
            block_pair.d = warp_d[lane];
        }
        block_pair = warp_reduce_online_pair(block_pair);
        if (lane == 0) {
            warp_m[0] = block_pair.m;
            warp_d[0] = block_pair.d;
        }
    }
    __syncthreads();

    float row_m = warp_m[0];
    float row_d = warp_d[0];
    for (int col = tid; col < cols; col += blockDim.x) {
        row_y[col] = expf(row_x[col] - row_m) / row_d;
    }
}

HOST侧封装

host侧封装成两个cpu函数,提供给ATen调用

void launch_online_softmax(const float* input, float* output, int rows, int cols) {
    hipLaunchKernelGGL(online_softmax_kernel,
                       dim3(rows),
                       dim3(kThreadsPerBlock),
                       0,
                       0,
                       input,
                       output,
                       rows,
                       cols);
}

void launch_naive_3pass_softmax(const float* input, float* output, int rows, int cols) {
    hipLaunchKernelGGL(naive_3pass_softmax_kernel,
                       dim3(rows),
                       dim3(kThreadsPerBlock),
                       0,
                       0,
                       input,
                       output,
                       rows,
                       cols);
}

ATen封装

ATen库注册,封装成两个输入输出是torch::tensor的函数,并pybind绑定到python函数

#include <torch/extension.h>

void launch_online_softmax(const float* input, float* output, int rows, int cols);
void launch_naive_3pass_softmax(const float* input, float* output, int rows, int cols);

#define CHECK_HIP(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a HIP/CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_FLOAT32(x) TORCH_CHECK(x.scalar_type() == torch::kFloat32, #x " must be float32")
#define CHECK_2D(x) TORCH_CHECK(x.dim() == 2, #x " must be a 2D tensor")
#define CHECK_INPUT(x) \
    CHECK_HIP(x);      \
    CHECK_CONTIGUOUS(x); \
    CHECK_FLOAT32(x);  \
    CHECK_2D(x)

torch::Tensor online_softmax_forward(torch::Tensor input) {
    CHECK_INPUT(input);

    auto output = torch::empty_like(input);
    int rows = static_cast<int>(input.size(0));
    int cols = static_cast<int>(input.size(1));

    launch_online_softmax(input.data_ptr<float>(), output.data_ptr<float>(), rows, cols);
    return output;
}

torch::Tensor naive_3pass_softmax_forward(torch::Tensor input) {
    CHECK_INPUT(input);

    auto output = torch::empty_like(input);
    int rows = static_cast<int>(input.size(0));
    int cols = static_cast<int>(input.size(1));

    launch_naive_3pass_softmax(input.data_ptr<float>(), output.data_ptr<float>(), rows, cols);
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("online_forward", &online_softmax_forward, "Online softmax forward (HIP)");
    m.def("naive_forward", &naive_3pass_softmax_forward, "Naive 3-pass softmax forward (HIP)");
}

python注册

用torch注册接口注册

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name="my_custom_softmax",
    ext_modules=[
        CUDAExtension(
            name="my_custom_softmax_backend",
            sources=["softmax_wrapper.cpp", "softmax_kernel.hip"],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

性能测试

调用ROCm朴素实现,ROCm online softmax实现,Torch接口,分别测试,对比精度误差和性能

import argparse
import time

import torch
import my_custom_softmax_backend


def online_softmax(x):
    return my_custom_softmax_backend.online_forward(x)


def naive_softmax(x):
    return my_custom_softmax_backend.naive_forward(x)


def bench(fn, x, warmup_iters, bench_iters):
    for _ in range(warmup_iters):
        fn(x)
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(bench_iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.time() - start) * 1000 / bench_iters


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--rows", type=int, default=4096)
    parser.add_argument("--cols", type=int, default=1024)
    parser.add_argument("--iters", type=int, default=200)
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--scale", type=float, default=10.0)
    args = parser.parse_args()

    x = torch.randn(args.rows, args.cols, device="cuda", dtype=torch.float32) * args.scale

    online = online_softmax(x)
    naive = naive_softmax(x)
    ref = torch.softmax(x, dim=-1)
    torch.cuda.synchronize()

    print(f"shape=({args.rows}, {args.cols})")
    print(f"max_abs_diff(online, torch) = {(online - ref).abs().max().item():.8g}")
    print(f"max_abs_diff(naive,  torch) = {(naive - ref).abs().max().item():.8g}")
    print(f"max_abs_diff(online, naive) = {(online - naive).abs().max().item():.8g}")

    torch_ms = bench(lambda t: torch.softmax(t, dim=-1), x, args.warmup, args.iters)
    online_ms = bench(online_softmax, x, args.warmup, args.iters)
    naive_ms = bench(naive_softmax, x, args.warmup, args.iters)

    print("Performance:")
    print(f"  torch.softmax:       {torch_ms:.4f} ms")
    print(f"  online two-pass:     {online_ms:.4f} ms")
    print(f"  naive three-pass:    {naive_ms:.4f} ms")
    print(f"  speedup online/naive {naive_ms / online_ms:.3f}x")
    print(f"  speedup online/torch {torch_ms / online_ms:.3f}x")


if __name__ == "__main__":
    main()

结果分析

来几个典型结果:cols较大时,online优于naive,但是如果rows较小,online还是不如torch。这分两方面看:

  • online打过naive是因为online少一次全局内存读写,肯定是比naive更优的
  • online打不过torch是因为online的分块方式是,每一个block负责一行,行数不够的时候并行度不够,无法喂饱CU计算单元,可能是block不够多,划分出的warp不够多,无法让CU超线程掩盖数据搬运延迟
root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax#  python bench_softmax.py --rows 256 --cols 8192 --iters 200 --warmup 20
shape=(256, 8192)
max_abs_diff(online, torch) = 3.5762787e-07
max_abs_diff(naive,  torch) = 3.5762787e-07
max_abs_diff(online, naive) = 1.1920929e-07
Performance:
  torch.softmax:       0.0225 ms
  online two-pass:     0.0282 ms
  naive three-pass:    0.0323 ms
  speedup online/naive 1.145x
  speedup online/torch 0.799x

row不大,col也不大,online和torch接近,但都不如naive

  • online和torch估计都是某种形式的online softmax,故接近
  • online在cols不大的时候,虽然少一次循环,少一次显存读写,但是合并计算的常数显然比正常的reduce max,reduce sum要大,合并时有两个exp调用,还有ifelse,故更慢。只有col更大的时候,节约的这一次显存读写,才能覆盖合并操作的开销,让online对naive实现反超
root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax# python bench_softmax.py --rows 256 --co
ls 2048 --iters 200 --warmup 20
shape=(256, 2048)
max_abs_diff(online, torch) = 3.5762787e-07
max_abs_diff(naive,  torch) = 3.5762787e-07
max_abs_diff(online, naive) = 1.1920929e-07
Performance:
  torch.softmax:       0.0111 ms
  online two-pass:     0.0108 ms
  naive three-pass:    0.0097 ms
  speedup online/naive 0.897x
  speedup online/torch 1.028x

继续提高cols,online相对于naive的优势更明显。naive 3-pass理论上有三次读显存,一次写显存,四次显存访问,online 2-pass有两次读显存,一次写显存,如果把访存视为主要瓶颈的话(事实上确实如此),online相对于naive的理论最大加速比为4/3=1.33x4/3=1.33x4/3=1.33x,所以能达到1.28x是合理的,甚至还有一定加速空间

root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax# python bench_softmax.py --rows 256 --co
ls 2097152 --iters 200 --warmup 20
shape=(256, 2097152)
max_abs_diff(online, torch) = 3.5762787e-07
max_abs_diff(naive,  torch) = 3.5762787e-07
max_abs_diff(online, naive) = 1.1920929e-07
Performance:
  torch.softmax:       11.5582 ms
  online two-pass:     11.9933 ms
  naive three-pass:    15.3814 ms
  speedup online/naive 1.283x
  speedup online/torch 0.964x

数据量继续增大,这里增加行数。

  • 数据量的变大会方法online在数据搬运上的优势,掩盖在计算上的劣势,加速比进一步放大,能达到1.32x,接近理论上限1.33x,相比torch都实现了1.15的优化
  • 并且这里增加行数,弥补了我们按行分块导致的block,warp数不足的缺点,有效实现了CU超线程,应该在超线程这一块和torch没有差距了,同时每一行的recude+2pass实现比torch更优,因此实现了由于torch的加速比
root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax# python bench_softmax.py --rows 4096 --c
ols 262144 --iters 200 --warmup 20
shape=(4096, 262144)
max_abs_diff(online, torch) = 4.7683716e-07
max_abs_diff(naive,  torch) = 4.7683716e-07
max_abs_diff(online, naive) = 1.7881393e-07
Performance:
  torch.softmax:       24.4696 ms
  online two-pass:     21.1227 ms
  naive three-pass:    27.9633 ms
  speedup online/naive 1.324x
  speedup online/torch 1.158x
Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐