ROCm 自定义softmax算子
自定义算子

之前这篇文章
https://blog.csdn.net/Maxwell_Newton/article/details/162127777?spm=1011.2415.3001.5331
学习了一下自定义ROCm算子接入python的流程,想着自己实现一个复杂一点的算子,跑通一遍这个流程
- HIP实现内核
- ATen注册
- pybind封装
- python调用,和torch对比测试
完整项目发布于仓库https://github.com/hereisaway/ROCm-online-softmax
算子定义
Softmax 算子(Softmax Operator)是深度学习和机器学习中最常用的激活函数之一,主要用于将一个包含任意实数的向量(通常称为 Logits)映射为一个概率分布。以下是 Softmax 算子的标准定义、数学公式以及其核心特性:
1. 数学定义
对于一个包含 NNN 个元素的输入向量 x=[x1,x2,…,xN]T\mathbf{x} = [x_1, x_2, \dots, x_N]^Tx=[x1,x2,…,xN]T,Softmax 算子会将其映射为相同维度的输出向量 s=[s1,s2,…,sN]T\mathbf{s} = [s_1, s_2, \dots, s_N]^Ts=[s1,s2,…,sN]T。对于输出向量中的第 iii 个元素 sis_isi,其计算公式如下:si=Softmax(xi)=exi∑j=1Nexjs_i = \text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}si=Softmax(xi)=∑j=1Nexjexi关键特性归一化(Probability Distribution):所有输出元素的值都在 (0,1)(0, 1)(0,1) 区间内,且它们的总和严格等于 1:∑i=1Nsi=1\sum_{i=1}^{N} s_i = 1i=1∑Nsi=1放大差异(“Soft” Max):它被称为“软最大值”,是因为它不仅会挑选出最大的那个数,还会通过指数函数(exie^{x_i}exi)拉大输入之间的差距。相比于 Hardmax(直接将最大值设为 1,其余设为 0),Softmax 保持了函数的可导性,适合梯度下降优化。
2. 硬件与工程实现中的实际工程定义
在 GPU 计算(如 CUDA 内核开发)或高性能算子优化中,直接使用上述标准公式进行计算会遭遇数值溢出(Numerical Overflow)的问题。因为当 xix_ixi 较大时(例如 xi=1000x_i = 1000xi=1000),e1000e^{1000}e1000 会直接导致浮点数溢出(inf)。因此,工业界(如 cuDNN, PyTorch, TensorRT)在实现 Softmax 算子时,通常使用的是平移不变性(Translation Invariance)修正后的安全版本。安全的 Softmax 计算公式令 M=max(x1,x2,…,xN)M = \max(x_1, x_2, \dots, x_N)M=max(x1,x2,…,xN),则有:si=exi−M∑j=1Nexj−Ms_i = \frac{e^{x_i - M}}{\sum_{j=1}^{N} e^{x_j - M}}si=∑j=1Nexj−Mexi−M通过减去最大值 MMM,输入向量的最大值被限制为 0(e0=1e^0 = 1e0=1),其余值全部 ≤0\le 0≤0,这保证了指数项 exi−Me^{x_i - M}exi−M 的取值范围严格在 (0,1](0, 1](0,1] 之间,彻底避免了上溢出(Overflow)问题(即使遭遇下溢出变为 0,也是安全的)。
实现思路
1. 朴素 Softmax (Naive Softmax)
为了保证数值安全性,工业界常用的“朴素”安全 Softmax 包含 3 个独立的循环步骤(即 3-pass 算法)。对于一个长度为 NNN 的向量 XXX:
- Pass 1 (求最大值):找出最大值 M=maxj(xj)M = \max_j(x_j)M=maxj(xj)。
- Pass 2 (求指数和):计算每个元素的指数并累加,D=∑jexj−MD = \sum_j e^{x_j - M}D=∑jexj−M。
- Pass 3 (归一化):计算最终结果 si=exi−MDs_i = \frac{e^{x_i - M}}{D}si=Dexi−M。
硬件视角的致命缺陷:多轮访存在 GPU 编程中,如果向量 NNN 的规模很大(或者处于 Transformer 的 Attention 矩阵计算中),这个 3-pass3\text{-pass}3-pass 的逻辑会带来巨大的显存(HBM/DRAM)带宽浪费:
- 第一次访存:从 Global Memory 读取 XXX,计算出 MMM,写回或暂存。
- 第二次访存:再次从 Global Memory 读取 XXX,计算出 DDD。
- 第三次访存:第三次从 Global Memory 读取 XXX,计算出 sis_isi,并将结果写回 Global Memory。
这意味着数据被反复读取了 3 次(3 次 Read,1 次 Write)。由于 Softmax 属于 Element-wise(访存密集型) 算子,计算核心(ALU)大部分时间都在等待数据从显存传输,算子性能完全受限于显存带宽(Memory-bound)。
2. Online Softmax
为了减少访存次数,FlashAttention 的核心前置技术——Online Softmax(由 Nvidia 的 Max Milakov 等人提出)应运而生。它的核心思想是:将“求最大值”和“求分母累加和”合二为一,在单次遍历(1-pass)中动态更新。
数学递推原理:假设我们已经处理了前 kkk 个元素,当前已知的最大值为 MkM_kMk,当前分母的指数累加和为 DkD_kDk。当处理第 k+1k+1k+1 个元素 xk+1x_{k+1}xk+1 时,我们需要动态更新这两个局部变量:
- 更新局部最大值 Mk+1M_{k+1}Mk+1:Mk+1=max(Mk,xk+1)M_{k+1} = \max(M_k, x_{k+1})Mk+1=max(Mk,xk+1)
- 动态修正并更新累加和 Dk+1D_{k+1}Dk+1:如果新的元素比之前的最大值还要大(xk+1>Mkx_{k+1} > M_kxk+1>Mk),那么之前基于 MkM_kMk 计算的指数和 DkD_kDk 就缩放错了比例。我们需要利用指数的性质对其进行重缩放(Rescaling)修正:Dk+1=Dk⋅eMk−Mk+1+exk+1−Mk+1D_{k+1} = D_k \cdot e^{M_k - M_{k+1}} + e^{x_{k+1} - M_{k+1}}Dk+1=Dk⋅eMk−Mk+1+exk+1−Mk+1
- 物理含义:eMk−Mk+1e^{M_k - M_{k+1}}eMk−Mk+1 恰好是将旧的累加和从以 MkM_kMk 为基准,对齐到以新的 Mk+1M_{k+1}Mk+1 为基准。
算法执行流程 (2-pass):通过这个递推公式,Online Softmax 将算法精简为了 2-pass:
- Pass 1 (Online 迭代):只需遍历一遍数据,就能同时得到全局最大值 MNM_NMN 和全局正确的指数分母和 DND_NDN。
- Pass 2 (归一化):再遍历一遍数据,直接输出最终的 si=exi−MNDNs_i = \frac{e^{x_i - M_N}}{D_N}si=DNexi−MN。
实现流程
HIP 朴素softmax内核
简单的3pass,也就是三次循环。分块思路是,输入的二维矩阵,每一行分给一个线程块。三次循环前两次都是reduce操作,一个max一个sum,用前一篇reduce算子的思路,warp reduce规约。第三次循环就是一个element-wise 除法,跨步循环即可。
const float* row_x = x + static_cast<int64_t>(row) * cols,当前block处理输入张量第第几行,找到这一行的起始地址float* row_y = y + static_cast<int64_t>(row) * cols;同样找到在输出张量的起始地址for (int col = tid; col < cols; col += blockDim.x)第一个循环,跨步循环,实现任意长度的规约,每个block只有blockdim.x个线程,所以blockdim.x为步长local_max = warp_reduce_max(local_max);跨步循环之后warp reduce求出每个warp的结果warp_buf[warp_id] = local_sum;每个warp第一个线程结果写入线程束对应的缓冲区row_m = lane < num_warps ? warp_buf[lane] : -INFINITY;第一个warp对每个warp的结果再次进行规约warp_buf[0] = row_m;至此得到了每一行的最大指数- 第二个循环,仍然是规约,只不过这次规约的是sum
local_sum += expf(row_x[col] - row_m);,local_sum = warp_reduce_sum(local_sum);别的都和第一轮差不多 row_y[col] = expf(row_x[col] - row_m) / row_d;最后一次循环用sum和max做缩放,计算每个位置的结果
__global__ void naive_3pass_softmax_kernel(const float* __restrict__ x,
float* __restrict__ y,
int rows,
int cols) {
int row = blockIdx.x;
if (row >= rows) {
return;
}
int tid = threadIdx.x;
int lane = tid % warpSize;
int warp_id = tid / warpSize;
int num_warps = (blockDim.x + warpSize - 1) / warpSize;
const float* row_x = x + static_cast<int64_t>(row) * cols;
float* row_y = y + static_cast<int64_t>(row) * cols;
__shared__ float warp_buf[kMaxWarpsPerBlock];
float local_max = -INFINITY;
for (int col = tid; col < cols; col += blockDim.x) {
local_max = fmaxf(local_max, row_x[col]);
}
local_max = warp_reduce_max(local_max);
if (lane == 0) {
warp_buf[warp_id] = local_max;
}
__syncthreads();
float row_m = -INFINITY;
if (warp_id == 0) {
row_m = lane < num_warps ? warp_buf[lane] : -INFINITY;
row_m = warp_reduce_max(row_m);
if (lane == 0) {
warp_buf[0] = row_m;
}
}
__syncthreads();
row_m = warp_buf[0];
float local_sum = 0.0f;
for (int col = tid; col < cols; col += blockDim.x) {
local_sum += expf(row_x[col] - row_m);
}
local_sum = warp_reduce_sum(local_sum);
if (lane == 0) {
warp_buf[warp_id] = local_sum;
}
__syncthreads();
float row_d = 0.0f;
if (warp_id == 0) {
row_d = lane < num_warps ? warp_buf[lane] : 0.0f;
row_d = warp_reduce_sum(row_d);
if (lane == 0) {
warp_buf[0] = row_d;
}
}
__syncthreads();
row_d = warp_buf[0];
for (int col = tid; col < cols; col += blockDim.x) {
row_y[col] = expf(row_x[col] - row_m) / row_d;
}
}
HIP online softmax内核
因为我们要在第一次循环就做完max和sum的规约,仍然考虑使用warp reduce规约框架,但是需要定义一个类,包含max和sum信息,然后定义这个类的合并规则,然后把原本的warp reduce里的max,+操作改成这个合并操作即可
local = online_update(local, row_x[col]);初始跨步循环,维护一个OnlinePair ,每次加入一个新元素,更新最大值和sumlocal = warp_reduce_online_pair(local);仍然线程束规约,只是规约的是OnlinePair ,内部调用v = online_combine(v, other);处理两个线程的合并,注意和online_update不同,这里合并的是两个OnlinePair 对象,而online_update是一个OnlinePair 和一个普通元素xonline_combine内,先确定最大值,这是好做的,然后要把最大值更小的一侧缩放,用ifelse不太好,会造成线程束发散,我们直接用这个通用公式out.d = a.d * expf(a.m - m) + b.d * expf(b.m - m);代价是会两次expf调用
struct OnlinePair {
float m;
float d;
};
__device__ __forceinline__ OnlinePair online_update(OnlinePair acc, float x) {
float new_m = fmaxf(acc.m, x);
acc.d = acc.d * expf(acc.m - new_m) + expf(x - new_m);
acc.m = new_m;
return acc;
}
__device__ __forceinline__ OnlinePair online_combine(OnlinePair a, OnlinePair b) {
if (a.d == 0.0f) {
return b;
}
if (b.d == 0.0f) {
return a;
}
float m = fmaxf(a.m, b.m);
OnlinePair out;
out.m = m;
out.d = a.d * expf(a.m - m) + b.d * expf(b.m - m);
return out;
}
__device__ __forceinline__ OnlinePair warp_reduce_online_pair(OnlinePair v) {
for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
OnlinePair other;
other.m = __shfl_down(v.m, offset);
other.d = __shfl_down(v.d, offset);
v = online_combine(v, other);
}
return v;
}
__device__ __forceinline__ float warp_reduce_max(float v) {
for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
v = fmaxf(v, __shfl_down(v, offset));
}
return v;
}
__device__ __forceinline__ float warp_reduce_sum(float v) {
for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
v += __shfl_down(v, offset);
}
return v;
}
__global__ void online_softmax_kernel(const float* __restrict__ x,
float* __restrict__ y,
int rows,
int cols) {
int row = blockIdx.x;
if (row >= rows) {
return;
}
int tid = threadIdx.x;
int lane = tid % warpSize;
int warp_id = tid / warpSize;
int num_warps = (blockDim.x + warpSize - 1) / warpSize;
const float* row_x = x + static_cast<int64_t>(row) * cols;
float* row_y = y + static_cast<int64_t>(row) * cols;
OnlinePair local{-INFINITY, 0.0f};
for (int col = tid; col < cols; col += blockDim.x) {
local = online_update(local, row_x[col]);
}
local = warp_reduce_online_pair(local);
__shared__ float warp_m[kMaxWarpsPerBlock];
__shared__ float warp_d[kMaxWarpsPerBlock];
if (lane == 0) {
warp_m[warp_id] = local.m;
warp_d[warp_id] = local.d;
}
__syncthreads();
OnlinePair block_pair{-INFINITY, 0.0f};
if (warp_id == 0) {
if (lane < num_warps) {
block_pair.m = warp_m[lane];
block_pair.d = warp_d[lane];
}
block_pair = warp_reduce_online_pair(block_pair);
if (lane == 0) {
warp_m[0] = block_pair.m;
warp_d[0] = block_pair.d;
}
}
__syncthreads();
float row_m = warp_m[0];
float row_d = warp_d[0];
for (int col = tid; col < cols; col += blockDim.x) {
row_y[col] = expf(row_x[col] - row_m) / row_d;
}
}
HOST侧封装
host侧封装成两个cpu函数,提供给ATen调用
void launch_online_softmax(const float* input, float* output, int rows, int cols) {
hipLaunchKernelGGL(online_softmax_kernel,
dim3(rows),
dim3(kThreadsPerBlock),
0,
0,
input,
output,
rows,
cols);
}
void launch_naive_3pass_softmax(const float* input, float* output, int rows, int cols) {
hipLaunchKernelGGL(naive_3pass_softmax_kernel,
dim3(rows),
dim3(kThreadsPerBlock),
0,
0,
input,
output,
rows,
cols);
}
ATen封装
ATen库注册,封装成两个输入输出是torch::tensor的函数,并pybind绑定到python函数
#include <torch/extension.h>
void launch_online_softmax(const float* input, float* output, int rows, int cols);
void launch_naive_3pass_softmax(const float* input, float* output, int rows, int cols);
#define CHECK_HIP(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a HIP/CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_FLOAT32(x) TORCH_CHECK(x.scalar_type() == torch::kFloat32, #x " must be float32")
#define CHECK_2D(x) TORCH_CHECK(x.dim() == 2, #x " must be a 2D tensor")
#define CHECK_INPUT(x) \
CHECK_HIP(x); \
CHECK_CONTIGUOUS(x); \
CHECK_FLOAT32(x); \
CHECK_2D(x)
torch::Tensor online_softmax_forward(torch::Tensor input) {
CHECK_INPUT(input);
auto output = torch::empty_like(input);
int rows = static_cast<int>(input.size(0));
int cols = static_cast<int>(input.size(1));
launch_online_softmax(input.data_ptr<float>(), output.data_ptr<float>(), rows, cols);
return output;
}
torch::Tensor naive_3pass_softmax_forward(torch::Tensor input) {
CHECK_INPUT(input);
auto output = torch::empty_like(input);
int rows = static_cast<int>(input.size(0));
int cols = static_cast<int>(input.size(1));
launch_naive_3pass_softmax(input.data_ptr<float>(), output.data_ptr<float>(), rows, cols);
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("online_forward", &online_softmax_forward, "Online softmax forward (HIP)");
m.def("naive_forward", &naive_3pass_softmax_forward, "Naive 3-pass softmax forward (HIP)");
}
python注册
用torch注册接口注册
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="my_custom_softmax",
ext_modules=[
CUDAExtension(
name="my_custom_softmax_backend",
sources=["softmax_wrapper.cpp", "softmax_kernel.hip"],
extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
)
],
cmdclass={"build_ext": BuildExtension},
)
性能测试
调用ROCm朴素实现,ROCm online softmax实现,Torch接口,分别测试,对比精度误差和性能
import argparse
import time
import torch
import my_custom_softmax_backend
def online_softmax(x):
return my_custom_softmax_backend.online_forward(x)
def naive_softmax(x):
return my_custom_softmax_backend.naive_forward(x)
def bench(fn, x, warmup_iters, bench_iters):
for _ in range(warmup_iters):
fn(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(bench_iters):
fn(x)
torch.cuda.synchronize()
return (time.time() - start) * 1000 / bench_iters
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--rows", type=int, default=4096)
parser.add_argument("--cols", type=int, default=1024)
parser.add_argument("--iters", type=int, default=200)
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--scale", type=float, default=10.0)
args = parser.parse_args()
x = torch.randn(args.rows, args.cols, device="cuda", dtype=torch.float32) * args.scale
online = online_softmax(x)
naive = naive_softmax(x)
ref = torch.softmax(x, dim=-1)
torch.cuda.synchronize()
print(f"shape=({args.rows}, {args.cols})")
print(f"max_abs_diff(online, torch) = {(online - ref).abs().max().item():.8g}")
print(f"max_abs_diff(naive, torch) = {(naive - ref).abs().max().item():.8g}")
print(f"max_abs_diff(online, naive) = {(online - naive).abs().max().item():.8g}")
torch_ms = bench(lambda t: torch.softmax(t, dim=-1), x, args.warmup, args.iters)
online_ms = bench(online_softmax, x, args.warmup, args.iters)
naive_ms = bench(naive_softmax, x, args.warmup, args.iters)
print("Performance:")
print(f" torch.softmax: {torch_ms:.4f} ms")
print(f" online two-pass: {online_ms:.4f} ms")
print(f" naive three-pass: {naive_ms:.4f} ms")
print(f" speedup online/naive {naive_ms / online_ms:.3f}x")
print(f" speedup online/torch {torch_ms / online_ms:.3f}x")
if __name__ == "__main__":
main()
结果分析
来几个典型结果:cols较大时,online优于naive,但是如果rows较小,online还是不如torch。这分两方面看:
- online打过naive是因为online少一次全局内存读写,肯定是比naive更优的
- online打不过torch是因为online的分块方式是,每一个block负责一行,行数不够的时候并行度不够,无法喂饱CU计算单元,可能是block不够多,划分出的warp不够多,无法让CU超线程掩盖数据搬运延迟
root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax# python bench_softmax.py --rows 256 --cols 8192 --iters 200 --warmup 20
shape=(256, 8192)
max_abs_diff(online, torch) = 3.5762787e-07
max_abs_diff(naive, torch) = 3.5762787e-07
max_abs_diff(online, naive) = 1.1920929e-07
Performance:
torch.softmax: 0.0225 ms
online two-pass: 0.0282 ms
naive three-pass: 0.0323 ms
speedup online/naive 1.145x
speedup online/torch 0.799x
row不大,col也不大,online和torch接近,但都不如naive
- online和torch估计都是某种形式的online softmax,故接近
- online在cols不大的时候,虽然少一次循环,少一次显存读写,但是合并计算的常数显然比正常的reduce max,reduce sum要大,合并时有两个exp调用,还有ifelse,故更慢。只有col更大的时候,节约的这一次显存读写,才能覆盖合并操作的开销,让online对naive实现反超
root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax# python bench_softmax.py --rows 256 --co
ls 2048 --iters 200 --warmup 20
shape=(256, 2048)
max_abs_diff(online, torch) = 3.5762787e-07
max_abs_diff(naive, torch) = 3.5762787e-07
max_abs_diff(online, naive) = 1.1920929e-07
Performance:
torch.softmax: 0.0111 ms
online two-pass: 0.0108 ms
naive three-pass: 0.0097 ms
speedup online/naive 0.897x
speedup online/torch 1.028x
继续提高cols,online相对于naive的优势更明显。naive 3-pass理论上有三次读显存,一次写显存,四次显存访问,online 2-pass有两次读显存,一次写显存,如果把访存视为主要瓶颈的话(事实上确实如此),online相对于naive的理论最大加速比为4/3=1.33x4/3=1.33x4/3=1.33x,所以能达到1.28x是合理的,甚至还有一定加速空间
root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax# python bench_softmax.py --rows 256 --co
ls 2097152 --iters 200 --warmup 20
shape=(256, 2097152)
max_abs_diff(online, torch) = 3.5762787e-07
max_abs_diff(naive, torch) = 3.5762787e-07
max_abs_diff(online, naive) = 1.1920929e-07
Performance:
torch.softmax: 11.5582 ms
online two-pass: 11.9933 ms
naive three-pass: 15.3814 ms
speedup online/naive 1.283x
speedup online/torch 0.964x
数据量继续增大,这里增加行数。
- 数据量的变大会方法online在数据搬运上的优势,掩盖在计算上的劣势,加速比进一步放大,能达到1.32x,接近理论上限1.33x,相比torch都实现了1.15的优化
- 并且这里增加行数,弥补了我们按行分块导致的block,warp数不足的缺点,有效实现了CU超线程,应该在超线程这一块和torch没有差距了,同时每一行的recude+2pass实现比torch更优,因此实现了由于torch的加速比
root@u-1917-fcb583f1:/workspace/repo/src/infra/custom-pytorch-operator/code/custom_softmax# python bench_softmax.py --rows 4096 --c
ols 262144 --iters 200 --warmup 20
shape=(4096, 262144)
max_abs_diff(online, torch) = 4.7683716e-07
max_abs_diff(naive, torch) = 4.7683716e-07
max_abs_diff(online, naive) = 1.7881393e-07
Performance:
torch.softmax: 24.4696 ms
online two-pass: 21.1227 ms
naive three-pass: 27.9633 ms
speedup online/naive 1.324x
speedup online/torch 1.158x
更多推荐

所有评论(0)