Tilelang-metax|MoE|Tilelang baseline
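Source
The competition's GitHub repository has three operators under the race_tests directory. Pull the source as follows, and remember to switch to the race branch: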
cd /data
git clone https://github.com/tile-ai/tilelang-metax.git
cd tilelang-metax
git checkout race
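The previous article walked through the torch baseline and the basic idea behind the MoE operator. This one looks at the tilelang baseline, which adds some optimizations on top of that basic implementation.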
tilelang baseline
custom_kernel
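Reading the call chain from top to bottom, the outermost layer is this custom operator function. The key point is that it passes the configuration parameters to RoutedMoEKernel, which compiles a kernel, and then hands that kernel to the MoE class for inference.
RoutedMoEKernel is the optimized tilelang kernel. The extra MoE wrapper exists because some operations are cheap to compute but awkward to express in a kernel; those stay on the Python side in torch, and only the compute-intensive work is passed to the tilelang kernel.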
def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
    """
    DeepSeek-style Mixture of Experts using Tilelang.
    Args:
        data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, hidden_size]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters
    Returns:
        output: Processed tensor [batch_size, seq_len, d_model]
    """
    input_tensor, weights, config = data
    # Compile the Tilelang kernel (group_sum = total number of (token, expert) pairs over all activated experts)
    routed_kernel = RoutedMoEKernel(
        config["d_hidden"],
        config["d_expert"],
        config["n_routed_experts"],
        group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
        group_count=config["n_routed_experts"]
    )
    moe = MoE(config, routed_kernel, weights, padding_M=128)
    output = moe(input_tensor)
    return output
MoE
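Now for the MoE class wrapper, starting with initialization:
- self.routed_kernel = routed_kernel: stores the compiled tilelang kernel for later calls.
- self.experts = nn.ModuleList(...): initializes the experts. Unlike the torch baseline, the model weights are passed into the MoE constructor and assigned to each expert as it is built, instead of initializing the shapes first and assigning the weights by hand from outside.
- self.gating_network = MoEGate(config, weights).to(self.device): stores the routing module defined elsewhere; it is used during inference.
- self.expert_cache = torch.zeros((config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device): allocates the result buffer once at construction, so forward does not have to allocate memory on every call.
- self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0): stacks the weights of all experts into a 3D tensor, with shape [n_experts, d_expert, d_hidden] for gate/up and [n_experts, d_hidden, d_expert] for down. This makes it easy to pass them to the tilelang kernel later.
- self.stacked_expert_tokens = torch.empty() and the buffers after it: likewise created at construction, to avoid allocating memory on every forward.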
# Full MoE module (Tilelang-optimized version)
# Handles routing, grouping, calling the Grouped GEMM kernel, and the scatter-reduce of the results
class MoE(nn.Module):
    def __init__(
        self, config: Dict, routed_kernel, weights: Dict, padding_M: int = 128
    ):
        super().__init__()
        self.config = config
        self.routed_kernel = routed_kernel  # compiled RoutedMoEKernel instance
        self.padding_M = padding_M  # token block size, used for padding alignment
        self.experts = nn.ModuleList(
            [
                Expert(
                    config,
                    gate=weights[f"experts.{i}.0.weight"],
                    up=weights[f"experts.{i}.1.weight"],
                    down=weights[f"experts.{i}.2.weight"],
                )
                for i in range(config["n_routed_experts"])
            ]
        )
        self.device = torch.device("cuda")
        self.gating_network = MoEGate(config, weights).to(self.device)
        # Preallocate the output accumulation buffer to avoid allocating device memory on every forward
        self.expert_cache = torch.zeros(
            (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device
        )
        # Stack all expert weights into 3D tensors so the kernel can index them by expert
        # Shapes: [n_experts, d_expert, d_hidden] (gate/up), [n_experts, d_hidden, d_expert] (down)
        self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0)
        self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0)
        self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0)
        # Buffer of token features arranged by expert group, shape [B*S*top_k, d_hidden]
        self.stacked_expert_tokens = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
            dtype=torch.float16,
            device=self.device,
        )
        # Routing weight (scalar) for each position, shape [B*S*top_k]
        self.stacked_expert_weights = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device
        )
        # Token index in the original x_flat for each position, used by the final scatter-reduce
        self.stacked_expert_tokens_idxs = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device
        )
        # Intermediate activation buffer written by Step 1 (gate*up), shape [B*S*top_k, d_expert]
        self.up_logits_routed = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]),
            dtype=torch.float16,
            device=self.device,
        )
        # Final kernel output buffer, shape [B*S*top_k, d_hidden]
        self.expert_output_routed = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
            dtype=torch.float16,
            device=self.device,
        )
forward
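This is the key part. The forward pass of the MoE class arranges the data into a layout the tilelang kernel can consume directly, then calls the kernel.
- expert_indices, expert_scores = self.gating_network(x): same as the torch baseline. Route first, then reshape the routing result, flattening the B and S dimensions.
- idxs = flat_expert_indices.argsort(): also the same as the torch baseline. Sort the routing result to get the original indices in sorted order, then count and take a prefix sum. The prefix-sum array marks the interval of tokens each expert is responsible for.
- for expert_id, end_idx in enumerate(tokens_per_expert): iterates over the intervals given by the prefix sum. All tokens inside an interval belong to the current expert.
- exp_token_idxs = token_idxs[start_idx:end_idx]: the list of token indices the current expert will process.
- expert_tokens = x_flat[exp_token_idxs]: indexing with an index tensor gathers the rows at those positions, which are the actual tokens to run through the expert.
- self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens: this is where it departs from the torch baseline. Instead of computing right away, the tokens of each expert are written into one large tensor that is later passed to the tilelang kernel. The original token ids and the routing weights are stacked in the same way.
- Calling the tilelang kernel also needs some auxiliary information:
  - group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device): the number of tokens per expert.
  - group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device): subtracting the counts from the prefix sum gives the start of each expert's interval. Before, [2,4,6] described the intervals [0,2), [2,4), [4,6) by their ends; now [0,2,4] describes the same intervals by their starts, which is more convenient to use. With $a_i$ the i-th start, expert i covers $[a_i, a_{i+1}-1]$ (for the last expert the upper end is the total count minus 1).
  - group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M: the expression looks complicated, but it is just padding alignment for robustness. First, the length of an expert's interval may not be a multiple of padding_M, so the ceiling division rounds it up to one. Second, the +1 adds extra padding: if the length is already an exact multiple of padding_M, another padding_M slots are appended, so there is always at least one padding slot between the intervals of different experts.
  - M = (math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"]): the number of blocks along the B * S * topk dimension used by the tilelang kernel, with block length block_token = padding_M = 128. The extra n_routed_experts blocks leave room for the padding of each expert.
  - for bx in range(M): a double loop. The outer loop walks over the blocks and the inner loop over the experts, looking for the last expert whose padded start does not exceed the start of the current block. That expert owns the block. For example, measuring positions in units of one block, if the padded starts are [0,3,6] and the current block is bx=1, the last start not exceeding 1 is 0, which belongs to expert 0, so block bx=1 belongs to expert 0. The kernel uses this table so each block knows which expert's weights to load.
- with torch.cuda.stream(routed_stream): submits the work inside a CUDA stream context (here routed_stream is the default stream). Inside it, the auxiliary arrays built above, the stacked weights, and the input tensor are passed to the tilelang kernel through self.routed_kernel(...).
- self.expert_cache = torch.scatter_reduce(...): the results are still added back to their token positions with torch.
- routed_output = self.expert_cache.view(*orig_shape): reshapes the flattened tensor back to the shape of the input.
Overall, only the expert inference part of the MoE calls the tilelang kernel. Weight loading, routing, data preprocessing, and postprocessing are all still torch, and all of them are potential optimization targets.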
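The sort, count, and prefix-sum steps above can be sketched in plain Python. This is only an illustrative stand-in for the torch calls: the helper name group_tokens_by_expert and the toy routing result are assumptions, and the sort here is stable while torch's argsort does not promise that by default.

```python
from itertools import accumulate


def group_tokens_by_expert(flat_expert_indices, top_k, n_experts):
    """Pure-Python stand-in for argsort / bincount / cumsum in MoE.forward."""
    # argsort: positions of the (token, expert) pairs ordered by expert id.
    idxs = sorted(range(len(flat_expert_indices)),
                  key=lambda p: flat_expert_indices[p])
    # bincount: how many pairs were routed to each expert.
    counts = [0] * n_experts
    for expert_id in flat_expert_indices:
        counts[expert_id] += 1
    # Inclusive prefix sum: the end of each expert's interval.
    ends = list(accumulate(counts))
    # Exclusive prefix sum: the start of each interval (group_offset).
    starts = [end - count for end, count in zip(ends, counts)]
    # Pair position // top_k recovers the original token index.
    token_idxs = [p // top_k for p in idxs]
    return idxs, counts, ends, starts, token_idxs


# Toy routing result: 3 tokens, top_k = 2, 3 experts.
flat = [0, 2, 1, 0, 2, 1]
idxs, counts, ends, starts, token_idxs = group_tokens_by_expert(
    flat, top_k=2, n_experts=3)
print(ends, starts)    # ends of the intervals, then their starts
print(token_idxs)      # original token index for each sorted position
```

The sketch sizes counts by n_experts explicitly. torch.bincount without minlength only returns entries up to the largest index that occurs, so a trailing expert with no tokens would shorten the array.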
    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape
        batch_size, seq_len, hidden_dim = x.shape
        # 1. Routing: get the top-k expert indices and weights for each token
        expert_indices, expert_scores = self.gating_network(x)
        flat_expert_indices = expert_indices.view(-1)  # [B*S*top_k]
        flat_expert_weights = expert_scores.view(-1)  # [B*S*top_k]
        x_flat = x.view(-1, hidden_dim)  # [B*S, d_hidden]
        # 2. Sort by expert ID so tokens of the same expert are contiguous (required by Grouped GEMM)
        idxs = flat_expert_indices.argsort()  # positions of the (token, expert) pairs after sorting
        counts = flat_expert_indices.bincount().cpu().numpy()  # number of tokens assigned to each expert
        # counts = flat_expert_indices.bincount()
        tokens_per_expert = counts.cumsum()  # prefix sum, gives each expert's interval [start, end)
        # tokens_per_expert = torch.cumsum(counts, dim=0)
        num_per_tok = self.config["n_experts_per_token"]
        # Map sorted positions back to original token indices (argsort index // top_k = token index)
        token_idxs = idxs // num_per_tok
        # 3. Fill the stacked_expert_tokens buffer with tokens grouped by expert
        for expert_id, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
            if start_idx == end_idx:
                continue  # no tokens assigned to this expert in this call
            exp_token_idxs = token_idxs[start_idx:end_idx]  # original token indices for this expert
            expert_tokens = x_flat[exp_token_idxs]  # gather token features
            # Write into the contiguous buffer to form the layout Grouped GEMM needs
            self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
            self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
            self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]]
        # 4. Build the auxiliary tensors Grouped GEMM needs
        # group_sizes: number of tokens per expert
        group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
        # group_offset: start offset of each expert in stacked_expert_tokens (unpadded)
        group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device)
        # group_padded_offsets: start offset of each expert in the padded virtual space
        # Each expert occupies ceil((count+1)/padding_M)*padding_M slots, which keeps blocks aligned
        # The +1 leaves at least one padding row at each boundary to avoid overruns
        group_padded_offsets = [0 for _ in range(len(group_sizes))]
        for i in range(1, len(group_sizes)):
            group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M
        # 5. Decide which expert each kernel block (bx) belongs to
        # For every bx, find the last expert with m_start_padded >= group_padded_offsets[i]
        block_token = 128
        M = (
            math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
            + self.config["n_routed_experts"]
        )
        group_idx_for_bx = [0 for _ in range(M)]
        for bx in range(M):
            m_start_padded = bx * block_token
            for i in range(self.config["n_routed_experts"]):
                if m_start_padded >= group_padded_offsets[i]:
                    group_idx_for_bx[bx] = i
        group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
        group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
        routed_stream = torch.cuda.default_stream()
        torch.cuda.synchronize()
        with torch.cuda.stream(routed_stream):
            # 6. Call the Tilelang Grouped GEMM kernel
            # Inputs: tokens grouped by expert + stacked expert weights + auxiliary index tensors
            # Output: expert_output_routed, each row already multiplied by its routing weight
            self.routed_kernel(
                self.stacked_expert_tokens,
                self.stacked_expert_w_gate,
                self.stacked_expert_w_up,
                self.stacked_expert_w_down,
                self.stacked_expert_weights,
                group_sizes,
                group_offset,
                group_padded_offsets,
                group_idx_for_bx,
                self.up_logits_routed,
                self.expert_output_routed,
            )
            # 7. Scatter-reduce: accumulate expert outputs back into expert_cache by original token index
            # A token may be processed by several experts (top-k > 1), so the rows are summed
            self.expert_cache = torch.scatter_reduce(
                self.expert_cache,
                0,
                self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
                self.expert_output_routed,
                reduce="sum",
            )
            routed_output = self.expert_cache.view(*orig_shape)
        torch.cuda.synchronize()
        return routed_output
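The padded offsets and the block-to-expert table from steps 4 and 5 can be reproduced in plain Python to see what they contain. This is a sketch with small made-up numbers (padding_M = block_token = 4 instead of 128); the helper name build_block_lookup is an assumption, not something from the repository.

```python
import math


def build_block_lookup(counts, padding_m, block_token):
    """Mirror steps 4 and 5 of forward with plain lists."""
    n_experts = len(counts)
    # Padded start of each expert; the +1 forces at least one padding row.
    padded_offsets = [0] * n_experts
    for i in range(1, n_experts):
        slots = math.ceil((counts[i - 1] + 1) / padding_m) * padding_m
        padded_offsets[i] = padded_offsets[i - 1] + slots
    # Block count: blocks for the real rows plus one spare block per expert.
    num_blocks = math.ceil(sum(counts) / block_token) + n_experts
    # For each block, keep the last expert whose padded start is <= block start.
    group_idx_for_bx = [0] * num_blocks
    for bx in range(num_blocks):
        m_start_padded = bx * block_token
        for i in range(n_experts):
            if m_start_padded >= padded_offsets[i]:
                group_idx_for_bx[bx] = i
    return padded_offsets, group_idx_for_bx


# Expert 1 has exactly 4 tokens, a multiple of padding_m, and still gets 8 slots.
counts = [3, 4, 1]
padded_offsets, group_idx_for_bx = build_block_lookup(
    counts, padding_m=4, block_token=4)
print(padded_offsets)     # [0, 4, 12]
print(group_idx_for_bx)   # [0, 1, 1, 2, 2]
```

Blocks 2 and 4 in this example contain only padding. The kernel handles such blocks through actual_rows, which evaluates to 0 for them.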
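Tilelang kernel, stage 1: gate + up
The tilelang kernel receives data that is already prepared and only has to run the expert computation of the MoE, that is gate + up + down. gate and up can be computed together, while down has to wait for the result of gate * up. The kernel is therefore split into two parts. The first part computes gate and up, then combines them as SiLU(gate) * up.
- with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): this sets the tiling. The input is the [B * S * topk, hid_dim] list of tokens to process plus the expert weight matrices $X_{up}, X_{gate}$, which act as [hid_dim, expert_dim] in the product (they are stored as [expert_dim, hid_dim] and transposed inside the GEMM). The output is the [B * S * topk, expert_dim] intermediate result. Looked at closely, this is a GEMM, and the classic approach is to tile the output. The block count and block length along B * S * topk were already fixed on the MoE forward side and appear here as M; the expert_dim dimension is tiled with block length block_dexpert.
- actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))): because of the padding, the M blocks are defined over the padded input, so reading real data requires converting back to real coordinates and knowing how many rows of the block are valid.
- for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): as in a classic GEMM, a [Block_M, Block_N] C tile needs an A strip of [Block_M, K] and a B strip of [K, Block_N]. K can be large, so the K dimension is tiled as well. All K tiles accumulate into the same C tile, so running them in parallel would cause write conflicts and they are processed serially. Pipelining still improves the efficiency of the serial loop, hence T.Pipelined.
- Inside the loop it is the usual GEMM pattern: load the A and B tiles, multiply, and accumulate into the result. T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, transpose_B=True) is called twice, once for gate and once for up.
- Finally, SiLU is applied to the gate result and the outcome is multiplied elementwise with the up result, with a double loop over the tile: for i, j in T.Parallel(block_token, block_dexpert):. The operation is elementwise and different positions have no data dependencies, so it can run in parallel.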
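The kernel evaluates the sigmoid inside SiLU with exp2 instead of exp. A small plain-Python check of that identity is sketched below; it assumes the kernel's scale constant holds log2(e), which is what the comments in the kernel suggest, and the function names are illustrative.

```python
import math

# log2(e); assumed to be the value of `scale` in the kernel.
LOG2E = 1.4426950408889634


def silu_exp2(x):
    """SiLU written the way the kernel does it: x * 1 / (1 + 2^(-x * log2(e)))."""
    return x * (1.0 / (1.0 + 2.0 ** (-x * LOG2E)))


def silu_ref(x):
    """Reference SiLU: x * sigmoid(x) = x / (1 + e^(-x))."""
    return x / (1.0 + math.exp(-x))


def gated_activation(gate, up):
    """Elementwise combination used in stage 1: SiLU(gate) * up."""
    return silu_exp2(gate) * up


for x in (-5.0, -1.0, 0.0, 0.5, 3.0):
    # The two formulations agree up to floating-point rounding.
    assert math.isclose(silu_exp2(x), silu_ref(x), rel_tol=1e-9, abs_tol=1e-12)
```

The rewrite works because exp(-x) = 2^(-x * log2(e)), and exp2 is typically the cheaper instruction on GPUs.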
        # =====================================================================
        # Step 1: compute gate logits and up logits, fused with the SiLU activation
        # Grid: [M, ceil(dexpert/block_dexpert)]
        # bx: token block index (padding included)
        # by: block index along the expert dimension
        # =====================================================================
        with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
            # Per-block buffers for the current input and weight tiles
            # (input_shared is allocated as a fragment despite its name)
            input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
            routed_expert_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
            routed_expert_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
            # Register accumulators in float32 precision
            gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
            up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
            # Swizzle: reorders block scheduling (rasterization) for better L2 cache locality
            T.use_swizzle(10)
            # Start row of this block in the padded token space
            m_start_padded = bx * block_token
            # Table lookup: which expert this block belongs to
            cur_group_idx = group_idx_for_bx[bx]
            cur_group_size = group_sizes[cur_group_idx]
            # Convert the padded offset into the offset in the real input
            # m_start = padded start - padded offset of the expert + real offset of the expert
            m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx]
            # Number of valid token rows in this block (the last block of a group may hold fewer than block_token)
            actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx])))
            T.clear(gate_logits_local)
            T.clear(up_logits_local)
            # Accumulate over blocks along the hidden dimension; pipelined prefetch hides memory latency
            for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
                # Load the input tile from global memory
                T.copy(
                    input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
                    input_shared,
                )
                # Load the gate weight tile of the current expert (shape [dexpert, dhidden], B is transposed in the GEMM)
                T.copy(
                    routed_expert_gate[
                        cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
                    ],
                    routed_expert_gate_shared,
                )
                # input @ gate^T -> gate_logits [block_token, block_dexpert]
                T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, transpose_B=True)
                # Load the up weight tile
                T.copy(
                    routed_expert_up[
                        cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
                    ],
                    routed_expert_up_shared,
                )
                # input @ up^T -> up_logits [block_token, block_dexpert]
                T.gemm(input_shared, routed_expert_up_shared, up_logits_local, transpose_B=True)
            # Fused SiLU activation, then elementwise product of gate and up
            # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
            # Sped up with exp2: sigmoid(x) = 1 / (1 + exp(-x)) = 1 / (1 + exp2(-x * log2e))
            for i, j in T.Parallel(block_token, block_dexpert):
                gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
                up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
            # Write the result back to global memory (valid rows only, padding rows are skipped)
            for i, j in T.Parallel(block_token, block_dexpert):
                if i < actual_rows:
                    up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]
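The conversion from padded block coordinates to real rows (m_start and actual_rows) is easy to check by hand. The plain-Python sketch below repeats the same arithmetic for a toy configuration; the helper name block_rows and all the numbers are made up for illustration.

```python
def block_rows(bx, block_token, group_idx_for_bx, group_sizes,
               group_offsets, group_padded_offsets):
    """Return (expert, first real row, number of valid rows) for block bx."""
    g = group_idx_for_bx[bx]
    m_start_padded = bx * block_token
    # Position of this block inside its expert's padded region.
    local = m_start_padded - group_padded_offsets[g]
    # Same position expressed in the real (unpadded) row space.
    m_start = local + group_offsets[g]
    # Rows left in the group, clamped to [0, block_token].
    actual_rows = max(0, min(block_token, group_sizes[g] - local))
    return g, m_start, actual_rows


# Toy setup: three experts with 3, 4 and 1 tokens, block length 4.
group_sizes = [3, 4, 1]
group_offsets = [0, 3, 7]            # unpadded starts
group_padded_offsets = [0, 4, 12]    # padded starts
group_idx_for_bx = [0, 1, 1, 2, 2]   # owner of each of the 5 blocks

layout = [
    block_rows(bx, 4, group_idx_for_bx, group_sizes,
               group_offsets, group_padded_offsets)
    for bx in range(5)
]
for bx, (g, m_start, rows) in enumerate(layout):
    print(f"bx={bx}: expert {g}, rows [{m_start}, {m_start + rows})")
```

Every real row is covered by exactly one block in this example, and the padding-only blocks get zero valid rows, so the guarded write-back skips them entirely.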
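Tilelang kernel, stage 2: down
This is almost identical to stage 1. It is again a matrix multiplication: [B * S * topk, expert_dim] times [expert_dim, hid_dim], giving [B * S * topk, hid_dim]. Because the output shape differs from stage 1, it is written as a separate stage.
In principle the two stages could be fused, which is another potential optimization. A fused kernel would not need a global tensor to hold the gate * up intermediate result, and the round trip between global memory and shared memory/registers would disappear.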
        # =====================================================================
        # Step 2: down projection, maps the intermediate activation back to the hidden dimension and multiplies by the routing weight
        # Grid: [M, ceil(dhidden/block_dhidden)]
        # bx: token block index (padding included, same as Step 1)
        # by: block index along the hidden dimension
        # =====================================================================
        with T.Kernel(M, T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
            up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
            routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
            output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype)
            T.use_swizzle(10)
            m_start_padded = bx * block_token
            cur_group_idx = group_idx_for_bx[bx]
            cur_group_size = group_sizes[cur_group_idx]
            m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx]
            actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx])))
            T.clear(output_local)
            # Accumulate over blocks along the expert dimension
            for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
                # Load the intermediate activation produced by Step 1
                T.copy(
                    up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
                    up_logits_shared,
                )
                # Load the down weight tile of the current expert (shape [dhidden, dexpert], B is transposed in the GEMM)
                T.copy(
                    routed_expert_down[
                        cur_group_idx, by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert
                    ],
                    routed_expert_down_shared,
                )
                # up_logits @ down^T -> output [block_token, block_dhidden]
                T.gemm(up_logits_shared, routed_expert_down_shared, output_local, transpose_B=True)
            # On write-back, multiply by the token's routing weight (scalar) and skip padding rows
            for i, j in T.Parallel(block_token, block_dhidden):
                if i < actual_rows:
                    output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]
    return kernel
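Once the stage 2 kernel has written its weighted rows, forward sums them back per token with torch.scatter_reduce. A minimal plain-Python sketch of that sum is shown below; the helper name scatter_sum and the toy values are illustrative only, and the rows are assumed to be already multiplied by their routing weights.

```python
def scatter_sum(cache, token_idxs, expert_rows):
    """Add each expert output row onto the row of its original token."""
    for tok, row in zip(token_idxs, expert_rows):
        for j, value in enumerate(row):
            cache[tok][j] += value
    return cache


# 3 tokens, top_k = 2, hidden size 2; rows are in expert-sorted order.
token_idxs = [0, 1, 1, 2, 0, 2]
expert_rows = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0],
               [4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]
cache = [[0.0, 0.0] for _ in range(3)]   # zero-initialized accumulator
scatter_sum(cache, token_idxs, expert_rows)
print(cache)   # [[6.0, 6.0], [5.0, 5.0], [10.0, 10.0]]
```

Like this sketch, torch.scatter_reduce with reduce="sum" adds onto the existing contents of its input by default (include_self=True), so the accumulator has to start from zeros for the per-token sums to be correct.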