不可不知小技巧|FlagGems 算子库 @pointwise_dynamic 巧解 Pointwise 算子通用性问题
在多元算力时代,Pointwise 算子的通用性比单纯的峰值性能更重要。FlagGems 通过 @pointwise_dynamic 动态代码生成 技术,把广播适配、非连续访存、多维度索引等复杂工程问题封装在框架层,让开发者回归算法本身。在这一设计背后, FlagGems 成为全球支持芯片数量最多、算子数量最大的 Triton 单一算子库,实现了 90% 以上算子性能持平 / 超越 CUDA 原生
在大模型推理与训练的全链路中,Pointwise 算子(逐元素操作)无处不在。从基础的 add、mul、clamp,到 gelu、silu 等激活函数,再到各类融合算子,它们数量多、调用频繁,是决定模型端到端性能的关键一环。
随着国产 AI 芯片生态快速发展,一套算子代码在多类硬件上通用、且保持高性能,成为算子落地的核心诉求。Pointwise 算子要真正做到 “通用性”,必须跨过两道难关:完美兼容 PyTorch 广播语义、正确处理非连续张量(Non-Contiguous Tensor)。
作为 FlagOS 核心组件的 FlagGems 高性能算子库,基于 Triton 打造了一套动态代码生成机制,用 @pointwise_dynamic 让开发者只需聚焦标量计算逻辑,无需手写索引、stride、mask、广播适配等样板代码,即可实现 Pointwise 算子的跨硬件、全场景、高性能通用。
这篇文章,我们就从一个看似基础的问题出发:Pointwise 算子如何做到真正的“通用”?深入浅出地引出 FlagGems 算子库如何用 @pointwise_dynamic,优雅解决 Pointwise 算子通用性问题。
从一个“简单”的add开始
如果只看 Triton 的基础示例,Pointwise 算子几乎没有门槛。一个最经典的向量加法 kernel,大致是这样:
@triton.jit
def add_kernel(a_ptr, b_ptr, o_ptr, N, TILE_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)
mask = offsets < N
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
o = a + b
tl.store(o_ptr + offsets, o, mask=mask)
def add(a,b):
o = torch.empty_like(a)
N = a.numel()
TILE_SIZE=512
grid = triton.cdiv(N,TILE_SIZE)
add_kernel[grid](a,b,o,N,TILE_SIZE)
return o
这个实现没有问题,但它隐含了一个前提——输入 tensor 必须是连续的、形状一致的。而实际上 PyTorch 代码很少满足这个前提。比如最常见的广播:
a = torch.randn(1, 128)
b = torch.randn(64, 128)
c = a + b
或者更隐蔽的情况:
# 创建一个长度为200的一维张量,取下标为偶数的元素(步长为2)
# 结果是一个长度为100但内存不连续的张量
a = torch.randn(200)[::2]
# 创建一个200x200的矩阵,行和列都每隔一行/列取一个(步长为2)
# 结果是一个100x100但内存不连续的张量
b = torch.randn(200, 200)[::2, ::2]
前者涉及 broadcast,后者则是典型的 non-contiguous tensor。这两种情况,会直接让“简单 kernel”失效,要么算错,要么根本跑不通。于是问题的本质就暴露出来了:Pointwise 算子真正的难点,从来不在计算逻辑,而在 Tensor 语义。
Pointwise 算子通用性的两大拦路虎
要让一个 Pointwise 算子“通用”,必须正确处理两件事:广播(broadcast) 和非连续张量(Non-Contiguous Tensor)。
广播的复杂之处在于,它并不会真的复制数据,而是通过“虚拟扩展”的方式改变 Tensor 的逻辑形状。这意味着,同一个元素可能会被重复读取,但它在内存里只存在一份。如果 kernel 仍然用线性地址访问,很容易读错位置。
而非连续 Tensor 更进一步,它直接改变了“如何从索引映射到内存地址”。在 PyTorch 中,一个元素的位置不再是简单的 base + offset,而是由 stride 决定的多维映射:
offset = i0 * stride0 + i1 * stride1 + ...
一旦忽略 stride,结果几乎必然是错误的。除了这些问题,真正棘手的是,这些信息在 Triton 里并不好表达。
Triton 的 JIT 函数不支持传入 list 或 tuple,也就是说,stride 这样的“动态维度数组”无法直接作为参数传递。只能把它拆成一个个标量参数,比如 stride0, stride1, stride2。
但问题是往往 Tensor 的维度(rank)是在运行时才知道的。这就形成了一个经典矛盾:kernel 需要根据 rank 写不同逻辑;Triton 却要求参数在编译期固定展开。
FlagGems 的思路:不要“写通用”,而是“生成通用”
面对这个问题,FlagGems 给出的解决方案是运行时动态生成适配代码:@pointwise_dynamic 装饰器会根据输入张量的rank、形状、stride、数据类型,自动生成适配的 Triton JIT Kernel 与外层 Wrapper,把开发者从繁琐的底层适配中解放出来,只需要写标量级计算逻辑即可。
使用方式看起来比较简单:
@pointwise_dynamic(is_tensor=[True, True, False]),promotion_methods=[(0,1,"DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha
def add(A, B, *, alpha=1):
return add_func(A, B, alpha)
从开发者的角度,只需要关心一件事:标量层面的计算逻辑,例如broadcast 怎么处理、、stride 怎么展开、多维索引怎么计算等问题,则全部由系统自动完成。
@pointwise_dynamic 的关键,在于它把“写 kernel”这件事,拆成了两个阶段:
第一阶段:在运行时理解 Tensor
当函数被调用时,它会先分析输入 Tensor:
-
推导 broadcast 后的 shape
-
确定当前计算的 rank
-
收集每个 Tensor 的 stride
第二阶段:按需生成 Kernel
接下来,它会根据 rank 动态生成一份 Triton kernel。如果当前是二维 Tensor,生成的代码逻辑大致是这样的:
mask = tid < (M * N)
i1 = tid % N
i0 = tid // N
a = tl.load(a_ptr + i0 * a_stride0 + i1 * a_stride1, mask=mask)
b = tl.load(b_ptr + i0 * b_stride0 + i1 * b_stride1, mask=mask)
o = a + b
tl.store(o_ptr + i0 * o_stride0 + i1 * o_stride1, o, mask=mask)
这里有几个关键点:多维索引是自动生成的;stride 被正确展开为标量参数;broadcast 语义已经被转换为统一的 stride 地址计算逻辑。也就是说,最终生成的 kernel,本质上和“手写高性能版本”是等价的。
除了动态生成代码,@pointwise_dynamic 内部通过两层特化机制,实现通用性与高性能的平衡。
第一层,PointwiseDynamicFunction 内部会维护一个按 rank 分类的缓存。当某种 rank 第一次出现时,FlagGems 会动态生成对应的 Wrapper 与 Triton JIT Function;后续再次遇到相同 rank,则直接复用缓存结果,避免重复生成与编译。
对于 rank=2 的场景,生成的 kernel 会自动展开 stride 地址计算逻辑,例如:
# 自动生成的 2D 地址计算逻辑
i1 = tid % N
i0 = tid // N
a = tl.load(a_ptr + i0 * a_stride0 + i1 * a_stride1, mask=mask)
b = tl.load(b_ptr + i0 * b_stride0 + i1 * b_stride1, mask=mask)
o = a + b
tl.store(o_ptr + i0 * o_stride0 + i1 * o_stride1, o, mask=mask)
第二层,在Triton Runtime层,按 constexpr 参数特化。比如 TILE_SIZE、num_warps 等参数变化时,Triton 会生成不同的编译版本(CompiledKernel)。这两层机制叠加,既有动态适配能力,又不会牺牲性能。
带来的价值:让Pointwise 算子真正“通用可用”
@pointwise_dynamic 为 FlagGems 带来了三个层面的核心价值:
-
极致的通用性:一套代码同时支持任意常见 rank 的 Tensor 场景、任意 broadcast 组合、任意非连续 Tensor,开发者只需关心标量计算逻辑。
-
大幅降低维护成本:FlagGems 目前支持 400+ 个算子,其中 Pointwise 类算子占相当比例。如果没有 @pointwise_dynamic,每个算子都需要手写处理 broadcast 和非连续的复杂逻辑,代码量至少膨胀 5-10 倍,且极易出错。
-
多芯片适配零修改:由于生成的 Triton IR 是硬件中立的,@pointwise_dynamic 本身不包含任何硬件假设。当 FlagGems 适配寒武纪、海光、摩尔线程等国产芯片时,大部分 Pointwise 算子逻辑无需修改,仅需后端 Triton 编译器完成目标硬件 lower 与调优。
接下来我们看一个实战案例。
优化前(PyTorch 风格实现):
def gelu(x):
return 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
FlagGems 优化后(在 @pointwise_dynamic 装饰的 kernel 中):
@pointwise_dynamic(is_tensor=[True])
@triton.jit
def gelu_func(x):
# 使用 fma 合并乘加操作,利用硬件指令
one = 1.0
cdf = 0.5 * (one + tl.math.erf(x * 0.7071067811865475))
return x * cdf
这个案例说明,@pointwise_dynamic 让开发者能够聚焦于“计算逻辑的正确与高效”,而框架层自动解决“如何让这段逻辑跑在任意形状、任意芯片上”的通用性问题。
结语
在多元算力时代,Pointwise 算子的通用性比单纯的峰值性能更重要。FlagGems 通过 @pointwise_dynamic 动态代码生成 技术,把广播适配、非连续访存、多维度索引等复杂工程问题封装在框架层,让开发者回归算法本身。
在这一设计背后, FlagGems 成为全球支持芯片数量最多、算子数量最大的 Triton 单一算子库,实现了 90% 以上算子性能持平 / 超越 CUDA 原生,完美适配国产多元 AI 芯片生态,为大模型在多元硬件上的高效部署提供了关键支撑。
如果你正在做Triton 算子开发、自定义 PyTorch Operator、多芯片适配等,希望这个思路能帮大家早理解、少踩坑。
FlagGems GitHub仓地址:https://github.com/flagos-ai/FlagGems
FlagGems GitCode仓地址:https://gitcode.com/flagos-ai/FlagGems
关于众智FlagOS社区
为解决不同 AI 芯片大规模落地应用,北京智源研究院联合众多科研机构、芯片企业、系统厂商、算法和软件相关单位等国内外机构共同发起并创立了众智 FlagOS 社区。成员单位包括北京智源研究院、中科院计算所、中科加禾、安谋科技、北京大学、北京师范大学、百度飞桨、硅基流动、寒武纪、海光信息、华为、基流科技、摩尔线程、沐曦股份、澎峰科技、清微智能、天数智芯、先进编译实验室、移动研究院、中国矿业大学(北京)等多家在 FlagOS 软件栈研发中做出卓越贡献的单位。
FlagOS 是一款专为异构 AI 芯片打造的开源、统一系统软件栈,支持 AI 模型一次开发即可无缝移植至各类硬件平台,大幅降低迁移与适配成本。它包括大型算子库、统一AI编译器、并行训推框架、统一通信库等核心开源项目,致力于构建「模型-系统-芯片」三层贯通的开放技术生态,通过“一次开发跨芯迁移”释放硬件计算潜力,打破不同芯片软件栈之间生态隔离。
官网:https://flagos.io
GitHub 项目地址:https://github.com/flagos-ai
GitCode 项目地址:https://gitcode.com/flagos-ai
SkillHub:https://skillhub.flagos.io
更多推荐


所有评论(0)