Source

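The competition's GitHub repository has three operators under the race_tests directory. Clone the source as follows, and remember to switch to the race branch: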

cd /data
git clone https://github.com/tile-ai/tilelang-metax.git
cd tilelang-metax
git checkout race

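The previous article covered the implementation approach of the torch baseline and the basic idea of the MoE operator. This one walks through the tilelang baseline, which adds several optimizations on top of that basic implementation.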

tilelang baseline

custom_kernel

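Going through the call chain from the top down, the outermost layer is this custom operator function. The key point is that it passes the configuration parameters to a RoutedMoEKernel, which is compiled into a kernel; that kernel is then handed to the MoE class, which runs the inference.

In effect, RoutedMoEKernel is the optimized tilelang kernel. It is still wrapped in a MoE class because some operations are cheap to compute but fiddly to express; those stay on the Python side in torch, and only the compute-intensive work is handed to the tilelang kernel.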

def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
    """
    DeepSeek-style Mixture of Experts using Tilelang.

    Args:
        data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, hidden_size]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters

    Returns:
        output: Processed tensor of shape [batch_size, seq_len, hidden_size]
    """
    input_tensor, weights, config = data

    # Compile the Tilelang kernel (group_sum = total token count over all activated experts, i.e. B * S * top_k)
    routed_kernel = RoutedMoEKernel(
        config["d_hidden"],
        config["d_expert"],
        config["n_routed_experts"],
        group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
        group_count=config["n_routed_experts"]
    )

    moe = MoE(config, routed_kernel, weights, padding_M=128)

    output = moe(input_tensor)

    return output
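
To make the two size arguments concrete, here is a minimal sketch of how group_sum and group_count follow from the config. The helper name routed_kernel_sizes and the example values are hypothetical and only serve as illustration; the formulas are the ones used in custom_kernel above.

```python
def routed_kernel_sizes(config: dict) -> dict:
    """Return the size arguments that custom_kernel passes to RoutedMoEKernel."""
    # Every token is routed to top_k experts, so the grouped input has B * S * top_k rows.
    group_sum = config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]
    # One group per routed expert.
    group_count = config["n_routed_experts"]
    return {"group_sum": group_sum, "group_count": group_count}


# Hypothetical configuration, for illustration only.
example_config = {
    "batch_size": 2,
    "seq_len": 4,
    "n_experts_per_token": 2,
    "n_routed_experts": 8,
}
sizes = routed_kernel_sizes(example_config)
```

With these example values the grouped input has 2 * 4 * 2 = 16 rows spread over 8 expert groups.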

MoE

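Now the MoE class wrapper, starting with initialization:

  • self.routed_kernel = routed_kernel stores the compiled tilelang kernel for later calls.
  • self.experts = nn.ModuleList(...) initializes the experts. Unlike the torch baseline, the model parameters are passed into the MoE constructor, and each expert is built with its parameters assigned directly, instead of first initializing by shape and then assigning the parameters by hand from outside.
  • self.gating_network = MoEGate(config, weights).to(self.device) also keeps the routing function, which is defined outside the class, for use during inference.
  • self.expert_cache = torch.zeros((config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device) allocates the result buffer once at construction time, so forward does not have to allocate memory on every call.
  • self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0) stacks the parameters of all experts into 3D tensors, of shape [n_experts, d_expert, d_hidden] for gate/up and [n_experts, d_hidden, d_expert] for down, which makes them easy to pass to the tilelang kernel later.
  • self.stacked_expert_tokens = torch.empty() and the buffers after it are likewise created at construction time to avoid allocating memory on every forward.

# Complete MoE module (Tilelang-optimized version)
# Handles routing, grouping, the Grouped GEMM kernel call, and the scatter-reduce of the results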
class MoE(nn.Module):
    def __init__(
        self, config: Dict, routed_kernel, weights: Dict, padding_M: int = 128
    ):
        super().__init__()
        self.config = config
        self.routed_kernel = routed_kernel  # compiled RoutedMoEKernel instance
        self.padding_M = padding_M          # token block size, used for padding alignment
        self.experts = nn.ModuleList(
            [
                Expert(
                    config,
                    gate=weights[f"experts.{i}.0.weight"],
                    up=weights[f"experts.{i}.1.weight"],
                    down=weights[f"experts.{i}.2.weight"],
                )
                for i in range(config["n_routed_experts"])
            ]
        )
        self.device = torch.device("cuda")
        self.gating_network = MoEGate(config, weights).to(self.device)

        # Preallocate the output accumulation buffer to avoid allocating GPU memory on every forward
        self.expert_cache = torch.zeros(
            (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device
        )
        # Stack all expert weights into 3D tensors so the kernel can index them by expert
        # Shapes: [n_experts, d_expert, d_hidden] (gate/up), [n_experts, d_hidden, d_expert] (down)
        self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0)
        self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0)
        self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0)

        # Token feature buffer arranged by expert group, shape [B*S*top_k, d_hidden]
        self.stacked_expert_tokens = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
            dtype=torch.float16,
            device=self.device,
        )
        # Routing weight (scalar) for each position, shape [B*S*top_k]
        self.stacked_expert_weights = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device
        )
        # Token index of each position in the original x_flat, used for the final scatter-reduce
        self.stacked_expert_tokens_idxs = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device
        )

        # Intermediate activation buffer written by Step 1 (gate*up), shape [B*S*top_k, d_expert]
        self.up_logits_routed = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]),
            dtype=torch.float16,
            device=self.device,
        )
        # Final kernel output buffer, shape [B*S*top_k, d_hidden]
        self.expert_output_routed = torch.empty(
            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
            dtype=torch.float16,
            device=self.device,
        )
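
As a quick reference, the shapes of the preallocated buffers can be summarized with a small sketch. The helper moe_buffer_shapes and the example config are hypothetical; the shapes themselves are read off the constructor above.

```python
def moe_buffer_shapes(config: dict) -> dict:
    """Shapes of the tensors that MoE.__init__ allocates once, keyed by attribute name."""
    tokens = config["batch_size"] * config["seq_len"]      # B * S
    routed = tokens * config["n_experts_per_token"]        # B * S * top_k
    n_experts = config["n_routed_experts"]
    d_hidden, d_expert = config["d_hidden"], config["d_expert"]
    return {
        "expert_cache": (tokens, d_hidden),
        "stacked_expert_w_gate": (n_experts, d_expert, d_hidden),
        "stacked_expert_w_up": (n_experts, d_expert, d_hidden),
        "stacked_expert_w_down": (n_experts, d_hidden, d_expert),
        "stacked_expert_tokens": (routed, d_hidden),
        "stacked_expert_weights": (routed,),
        "stacked_expert_tokens_idxs": (routed,),
        "up_logits_routed": (routed, d_expert),
        "expert_output_routed": (routed, d_hidden),
    }


# Hypothetical configuration, for illustration only.
shapes = moe_buffer_shapes({
    "batch_size": 1, "seq_len": 4, "n_experts_per_token": 2,
    "n_routed_experts": 4, "d_hidden": 8, "d_expert": 16,
})
```

Only expert_cache is sized by B * S; every other activation buffer has B * S * top_k rows because each token appears once per selected expert.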

forward

The forward pass of the MoE class is the key part: it arranges the data into a layout the tilelang kernel can consume directly and then calls the kernel.

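  • expert_indices, expert_scores = self.gating_network(x): as in the torch baseline, route first, then reshape the routing result by flattening the B and S dimensions.
  • idxs = flat_expert_indices.argsort(): also the same as the torch baseline. The routing result is sorted while keeping the original positions, then the tokens per expert are counted and a prefix sum is taken. The prefix-sum array marks the token range each expert is responsible for.
  • for expert_id, end_idx in enumerate(tokens_per_expert): iterates over the ranges given by the prefix sum; every token inside a range belongs to the current expert.
  • exp_token_idxs = token_idxs[start_idx:end_idx] takes the list of token indices the current expert has to process.
  • expert_tokens = x_flat[exp_token_idxs] uses tensor indexing with an index list: it gathers the rows at those positions, which here are the actual tokens to process.
  • self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens: this is where the code departs from the torch baseline. Instead of computing right away, each expert's tokens are written into a contiguous slice of one large tensor that is later passed to the tilelang kernel. The original token ids and the routing weights are stacked in the same way.
  • Calling the tilelang kernel needs some extra auxiliary data. group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) is the number of tokens per expert.
  • group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device) subtracts the counts from the prefix sum, which gives the start of each expert's range. Before, [2, 4, 6] described the ranges [0, 2), [2, 4), [4, 6) by their ends; now [0, 2, 4] describes the same ranges by their starts, which is more convenient: the range of expert i is $[a_i, a_{i+1}-1]$.
  • group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M: the expression looks complicated, but it is just padding alignment for robustness. First, the length of an expert's range may not be a multiple of padding_M, so it is rounded up to the next multiple. Second, the + 1 guarantees extra padding: even when the length is already an exact multiple of padding_M, one more slot of length padding_M is appended, so there are always padding rows between the ranges of different experts.
  • M = (math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"]) is the number of blocks along the B * S * topk dimension used by the tilelang kernel, with block length block_token = padding_M = 128. The additional n_routed_experts blocks leave room for the padding each expert can introduce.
  • for bx in range(M): two nested loops. The outer loop walks over the blocks and the inner loop over the experts, looking for the last expert whose padded start offset does not exceed the start of the current block; that expert owns the block. For example, measured in blocks, if the experts start at [0, 3, 6] and the current block is bx = 1, the last start not exceeding 1 is 0, which belongs to expert 0, so block bx = 1 belongs to expert 0. At inference time this lets every block look up which expert's weights it should load.
  • with torch.cuda.stream(routed_stream): submits the work inside a CUDA stream context. Note that routed_stream comes from torch.cuda.default_stream(), so this is the default stream rather than a newly created one.
  • Inside the stream context, the auxiliary arrays built above, the stacked parameters, and the input tensor are passed to the tilelang kernel through self.routed_kernel(...).
  • self.expert_cache = torch.scatter_reduce(...): the results are still added back to their token positions with torch.
  • routed_output = self.expert_cache.view(*orig_shape) reshapes the flattened tensor back to the shape of the input tensor.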
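
The sort, count, and prefix-sum steps from the list above can be sketched in pure Python. This is only an illustration of the bookkeeping, not the torch code itself; group_by_expert and the example routing result are hypothetical names and values.

```python
from itertools import accumulate


def group_by_expert(flat_expert_indices, top_k, n_experts):
    """Mimic argsort / bincount / cumsum on a flat list of expert ids."""
    # Stable argsort: positions of the (token, expert) pairs ordered by expert id.
    # torch.argsort does not promise stability by default; the order inside a group does not matter here.
    idxs = sorted(range(len(flat_expert_indices)), key=lambda p: flat_expert_indices[p])
    # bincount: how many pairs were routed to each expert.
    counts = [0] * n_experts
    for expert_id in flat_expert_indices:
        counts[expert_id] += 1
    # Inclusive prefix sum gives the end of each expert's range ...
    ends = list(accumulate(counts))
    # ... and subtracting the counts gives the start (this is group_offset).
    starts = [end - count for end, count in zip(ends, counts)]
    # Each token contributes top_k consecutive pairs, so pair index // top_k is the token index.
    token_idxs = [p // top_k for p in idxs]
    return idxs, counts, ends, starts, token_idxs


# Three tokens with top_k = 2: token 0 -> experts (1, 0), token 1 -> (0, 2), token 2 -> (1, 0).
flat = [1, 0, 0, 2, 1, 0]
idxs, counts, ends, starts, token_idxs = group_by_expert(flat, top_k=2, n_experts=3)
```

Expert 0 then owns positions [0, 3) of the sorted order, expert 1 owns [3, 5), and expert 2 owns [5, 6).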

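Overall, only the MoE inference part calls the tilelang kernel; parameter loading, routing, data preprocessing, and postprocessing are all still done in torch. All of these are potential optimization targets.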
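
The padded offsets and the block-to-expert lookup table are easier to follow on a small example. The sketch below reproduces the same arithmetic as forward in pure Python; the function name padded_layout and the example counts are assumptions made for illustration.

```python
import math


def padded_layout(counts, padding_m=128, block_token=128):
    """Compute group_padded_offsets and group_idx_for_bx from per-expert token counts."""
    n_experts = len(counts)
    # Each expert occupies ceil((count + 1) / padding_m) * padding_m padded rows.
    padded_offsets = [0] * n_experts
    for i in range(1, n_experts):
        slots = math.ceil((counts[i - 1] + 1) / padding_m) * padding_m
        padded_offsets[i] = padded_offsets[i - 1] + slots
    # Number of token blocks: enough for all real rows plus one extra block per expert.
    num_blocks = math.ceil(sum(counts) / block_token) + n_experts
    # For each block, keep the last expert whose padded start is not beyond the block start.
    group_idx_for_bx = [0] * num_blocks
    for bx in range(num_blocks):
        m_start_padded = bx * block_token
        for i in range(n_experts):
            if m_start_padded >= padded_offsets[i]:
                group_idx_for_bx[bx] = i
    return padded_offsets, group_idx_for_bx


# Expert 0 gets 2 tokens, expert 1 exactly 128, expert 2 none.
padded_offsets, group_idx_for_bx = padded_layout([2, 128, 0])
```

Expert 1 holds exactly 128 tokens and still receives two blocks (256 padded rows), which is the effect of the + 1 in the formula.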

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape
        batch_size, seq_len, hidden_dim = x.shape

        # 1. Routing: get the top-k expert indices and weights for each token
        expert_indices, expert_scores = self.gating_network(x)
        flat_expert_indices = expert_indices.view(-1)       # [B*S*top_k]
        flat_expert_weights = expert_scores.view(-1)        # [B*S*top_k]
        x_flat = x.view(-1, hidden_dim)                     # [B*S, d_hidden]

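        # 2. Sort by expert ID so that tokens of the same expert are contiguous (a prerequisite for Grouped GEMM)
        idxs = flat_expert_indices.argsort()                # sorted indices of the (token, expert) pairs
        # minlength keeps one count per expert even when the highest-numbered experts receive no tokens
        counts = flat_expert_indices.bincount(minlength=self.config["n_routed_experts"]).cpu().numpy()  # tokens per expert
        # counts = flat_expert_indices.bincount()
        tokens_per_expert = counts.cumsum()                 # prefix sum: expert i owns the range [start, end)
        # tokens_per_expert = torch.cumsum(counts, dim=0)
        num_per_tok = self.config["n_experts_per_token"]
        # Map sorted positions back to original token indices (argsort index // top_k = token index)
        token_idxs = idxs // num_per_tok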

        # 3. Fill the stacked_expert_tokens buffer with tokens grouped by expert
        for expert_id, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
            if start_idx == end_idx:
                continue  # no tokens were routed to this expert in this pass

            exp_token_idxs = token_idxs[start_idx:end_idx]    # original token indices handled by this expert
            expert_tokens = x_flat[exp_token_idxs]             # gather the token features

            # Write into the contiguous buffer to form the memory layout Grouped GEMM expects
            self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
            self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
            self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]]

        # 4. Build the auxiliary tensors needed by Grouped GEMM
        # group_sizes: number of tokens per expert
        group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
        # group_offset: start offset of each expert in stacked_expert_tokens (unpadded)
        group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device)

        # group_padded_offsets: start offset of each expert in the padded virtual space
        # Each expert takes ceil((count+1)/padding_M)*padding_M slots, which keeps blocks aligned
        # The +1 leaves at least one padding row at each boundary to avoid running out of bounds
        group_padded_offsets = [0 for _ in range(len(group_sizes))]
        for i in range(1, len(group_sizes)):
            group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M

        # 5. Determine which expert each kernel block (bx) belongs to
        # For every bx, find the last expert i with m_start_padded >= group_padded_offsets[i]
        block_token = 128
        M = (
            math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
            + self.config["n_routed_experts"]
        )
        group_idx_for_bx = [0 for _ in range(M)]

        for bx in range(M):
            m_start_padded = bx * block_token
            for i in range(self.config["n_routed_experts"]):
                if m_start_padded >= group_padded_offsets[i]:
                    group_idx_for_bx[bx] = i

        group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
        group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)

        routed_stream = torch.cuda.default_stream()
        torch.cuda.synchronize()

        with torch.cuda.stream(routed_stream):
            # 6. Call the Tilelang Grouped GEMM kernel
            # Inputs: tokens grouped by expert + stacked expert weights + auxiliary index tensors
            # Output: expert_output_routed, with each row already multiplied by its routing weight
            self.routed_kernel(
                self.stacked_expert_tokens,
                self.stacked_expert_w_gate,
                self.stacked_expert_w_up,
                self.stacked_expert_w_down,
                self.stacked_expert_weights,
                group_sizes,
                group_offset,
                group_padded_offsets,
                group_idx_for_bx,
                self.up_logits_routed,
                self.expert_output_routed,
            )

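            # 7. Scatter-reduce: accumulate each expert's output back into expert_cache by original token index
            # The same token may be processed by several experts (top-k > 1), so the results are summed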
            self.expert_cache = torch.scatter_reduce(
                self.expert_cache,
                0,
                self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
                self.expert_output_routed,
                reduce="sum",
            )
            routed_output = self.expert_cache.view(*orig_shape)

        torch.cuda.synchronize()

        return routed_output
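
The final scatter-reduce step amounts to adding each routed output row to the row of its original token. A minimal pure-Python sketch of that summation, with the hypothetical helper scatter_sum and made-up values, looks like this:

```python
def scatter_sum(num_tokens, token_idxs, expert_rows):
    """Sum each routed output row into the row of the token it came from."""
    width = len(expert_rows[0])
    out = [[0.0] * width for _ in range(num_tokens)]
    for token, row in zip(token_idxs, expert_rows):
        for j, value in enumerate(row):
            out[token][j] += value  # several experts may contribute to the same token
    return out


# Rows 0 and 2 both belong to token 0 (it was routed to two experts); row 1 belongs to token 1.
combined = scatter_sum(
    num_tokens=2,
    token_idxs=[0, 1, 0],
    expert_rows=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
)
```

The routing weights do not appear here because the kernel has already multiplied them into each row.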

tilelang kernel, stage 1: gate + up

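The tilelang kernel receives data that is already prepared, so it only has to run the expert inference of the MoE, that is gate + up + down. gate and up can be computed at the same time, while down has to wait for the gate * up result. The kernel is therefore split into two parts: the first computes gate and up, and then the product SiLU(gate) * up.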
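
For reference, the computation the two stages implement for a single token and a single expert can be written in a few lines of pure Python. This is a sketch of the math only, assuming weights stored as [out_features, in_features] as in the stacked tensors; the helper names are hypothetical.

```python
import math


def silu(x):
    """SiLU(x) = x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matvec(weight, vec):
    """weight has shape [out, in]; returns weight @ vec."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def expert_forward(x, w_gate, w_up, w_down):
    """down(SiLU(gate(x)) * up(x)) for one token x of length d_hidden."""
    gate = matvec(w_gate, x)   # stage 1: gate projection, [d_expert]
    up = matvec(w_up, x)       # stage 1: up projection, [d_expert]
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(w_down, hidden)  # stage 2: down projection, [d_hidden]


identity = [[1.0, 0.0], [0.0, 1.0]]
y = expert_forward([1.0, 2.0], identity, identity, identity)
```

With identity weights the output is simply SiLU(x) * x per element, which makes the example easy to check by hand.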

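  • with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): first, the tiling scheme. The inputs are the [B * S * topk, hid_dim] list of tokens to process and the expert parameter matrices $X_{up}, X_{gate}$ of shape [hid_dim, expert_dim] (stored as [expert_dim, hid_dim] and transposed inside the GEMM); the output is the [B * S * topk, expert_dim] intermediate result. Looked at closely, this is a GEMM operator, and the classic approach is to tile the result matrix. The number of blocks and the block length along the B * S * topk dimension were already fixed in the MoE forward and arrive here as the parameter M; the expert_dim dimension is tiled with block length block_dexpert.
  • actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx]))): needed because of the earlier padding. The M blocks are defined over the padded layout, so to fetch real data the block position has to be converted back to real coordinates, and the number of valid rows in the block is clamped to [0, block_token].
  • for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): as in a classic GEMM, computing one [Block_M, Block_N] c tile needs an a tile and a b tile that are long strips of shape [Block_M, K] and [K, Block_N]. K can be large, so the K dimension is tiled as well. Since all K blocks accumulate into the same c tile, running them in parallel could cause write conflicts, so they are processed serially. A pipeline still improves the efficiency of this serial loop, which is why the iteration uses T.Pipelined.
  • The loop body follows the usual GEMM pattern: load an a tile and a b tile, multiply, and accumulate into the result, as in T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, transpose_B=True). Two matrix multiplications are computed here, one for gate and one for up.
  • Finally, SiLU is applied to the gate result, which is then multiplied elementwise with the up result. A two-level loop, for i, j in T.Parallel(block_token, block_dexpert):, does this; the elements have no data dependencies on each other, so they can be processed in parallel.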
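
The conversion from padded block coordinates to real row coordinates can be checked in isolation. The sketch below mirrors the m_start and actual_rows expressions of the kernel in plain Python; block_rows is a hypothetical helper, and the example layout (counts [2, 128, 0]) is made up for illustration.

```python
def block_rows(bx, group_idx_for_bx, group_sizes, group_offsets, group_padded_offsets, block_token=128):
    """Return (expert, first real row, number of valid rows) for token block bx."""
    group = group_idx_for_bx[bx]
    m_start_padded = bx * block_token
    # Position of this block relative to the start of its expert, in padded space.
    local = m_start_padded - group_padded_offsets[group]
    # Same relative position, but counted from the expert's real (unpadded) start.
    m_start = local + group_offsets[group]
    # Rows past the expert's token count are padding and must not be written.
    actual_rows = max(0, min(block_token, group_sizes[group] - local))
    return group, m_start, actual_rows


# Layout for counts [2, 128, 0]: real offsets [0, 2, 130], padded offsets [0, 128, 384].
layout = dict(
    group_idx_for_bx=[0, 1, 1, 2, 2],
    group_sizes=[2, 128, 0],
    group_offsets=[0, 2, 130],
    group_padded_offsets=[0, 128, 384],
)
rows = [block_rows(bx, **layout) for bx in range(5)]
```

Blocks 2 to 4 end up with zero valid rows: they exist only because of padding, compute on whatever they load, and write nothing back.
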
        # =====================================================================
        # Step 1: compute the gate and up logits, with the SiLU activation fused in
        # Grid: [M, ceil(dexpert/block_dexpert)]
        #   bx: token block index (padding included)
        #   by: block index along the expert dimension
        # =====================================================================
        with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
            # On-chip buffers for the current tile: the input tile (a fragment) and the weight tiles (shared memory)
            input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
            routed_expert_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
            routed_expert_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)

            # Register accumulators, in float32 precision
            gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
            up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)

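            # swizzle: reorder thread-block scheduling (panel size 10) for better L2 cache locality
            T.use_swizzle(10)

            # Token start position of the current block in the padded space
            m_start_padded = bx * block_token

            # Table lookup: which expert this block belongs to
            cur_group_idx = group_idx_for_bx[bx]

            cur_group_size = group_sizes[cur_group_idx]
            # Convert the padded offset into the offset within the actual input
            # m_start = padded start - the expert's padded offset + the expert's real offset
            m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx]
            # Number of valid token rows in this block (the last block of an expert may hold fewer than block_token)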
            actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx])))

            T.clear(gate_logits_local)
            T.clear(up_logits_local)

            # Accumulate block by block along the hidden dimension; pipelined prefetching hides memory latency
            for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
                # Load the input tile from global memory
                T.copy(
                    input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
                    input_shared,
                )
                # Load the gate weight tile of the current expert (shape [dexpert, dhidden]; B is transposed in the GEMM)
                T.copy(
                    routed_expert_gate[
                        cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
                    ],
                    routed_expert_gate_shared,
                )
                # input @ gate^T -> gate_logits  [block_token, block_dexpert]
                T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, transpose_B=True)
                # Load the up weight tile
                T.copy(
                    routed_expert_up[
                        cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
                    ],
                    routed_expert_up_shared,
                )
                # input @ up^T -> up_logits  [block_token, block_dexpert]
                T.gemm(input_shared, routed_expert_up_shared, up_logits_local, transpose_B=True)

            # Fuse the SiLU activation and multiply gate and up elementwise
            # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
            # exp2 is used for speed: sigmoid(x) = 1 / (1 + exp(-x)) = 1 / (1 + exp2(-x * log2e))
            for i, j in T.Parallel(block_token, block_dexpert):
                gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
                up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]

            # Write the result back to global memory (valid rows only; padding rows are skipped)
            for i, j in T.Parallel(block_token, block_dexpert):
                if i < actual_rows:
                    up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]
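
The exp2 form of SiLU used in the kernel can be verified numerically. The sketch below assumes that scale in the kernel equals log2(e), roughly 1.442695, as the comment in the listing suggests; the listing itself does not show where scale is defined.

```python
import math

LOG2E = math.log2(math.e)  # assumed value of `scale` in the kernel


def silu_exp2(x):
    """SiLU written with a base-2 exponential, as in the kernel."""
    return x * (1.0 / (1.0 + 2.0 ** (-x * LOG2E)))


def silu_reference(x):
    """Textbook SiLU: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


samples = [-4.0, -1.0, 0.0, 0.5, 3.0]
max_error = max(abs(silu_exp2(x) - silu_reference(x)) for x in samples)
```

The two forms agree because exp(-x) = 2^(-x * log2(e)); the base-2 variant is typically preferred on GPUs where exp2 maps to a fast hardware instruction.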

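tilelang kernel, stage 2: down

This stage is almost identical to stage 1: again a matrix multiplication, [B * S * topk, expert_dim] times [expert_dim, hid_dim], producing [B * S * topk, hid_dim]. Because the output shape differs from stage 1, it is a separate stage.

In practice the two stages could be fused, which is another potential optimization: the global tensor that holds the intermediate gate * up result would no longer be needed, nor the round trip of writing it to global memory and reading it back into shared memory/registers.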

        # =====================================================================
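        # Step 2: down projection, mapping the intermediate activation back to the hidden dimension and multiplying by the routing weight
        # Grid: [M, ceil(dhidden/block_dhidden)]
        #   bx: token block index (padding included, same as Step 1)
        #   by: block index along the hidden dimension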
        # =====================================================================
        with T.Kernel(M, T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
            up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
            routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
            output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype)

            T.use_swizzle(10)

            m_start_padded = bx * block_token
            cur_group_idx = group_idx_for_bx[bx]

            cur_group_size = group_sizes[cur_group_idx]
            m_start = m_start_padded - group_padded_offsets[cur_group_idx] + group_offsets[cur_group_idx]
            actual_rows = T.max(0, T.min(block_token, cur_group_size - (m_start_padded - group_padded_offsets[cur_group_idx])))

            T.clear(output_local)

            # Accumulate block by block along the expert dimension
            for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
                # Load the intermediate activation from Step 1
                T.copy(
                    up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
                    up_logits_shared,
                )
                # Load the down weight tile of the current expert (shape [dhidden, dexpert]; B is transposed in the GEMM)
                T.copy(
                    routed_expert_down[
                        cur_group_idx, by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert
                    ],
                    routed_expert_down_shared,
                )
                # up_logits @ down^T -> output  [block_token, block_dhidden]
                T.gemm(up_logits_shared, routed_expert_down_shared, output_local, transpose_B=True)

            # Multiply by the token's routing weight (scalar) on write-back, skipping padding rows
            for i, j in T.Parallel(block_token, block_dhidden):
                if i < actual_rows:
                    output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]

    return kernel
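
Both stages share the same accumulation pattern: the reduction dimension is split into tiles, and each tile adds its partial product into the same output tile, with B stored transposed. A minimal pure-Python sketch of that pattern follows; tiled_matmul_bt is a hypothetical helper, and it deliberately ignores pipelining and on-chip memory.

```python
def tiled_matmul_bt(a, b, block_k):
    """Compute C = A @ B^T by accumulating over tiles of the K dimension.

    a has shape [M, K]; b has shape [N, K] (row-major weights, as with transpose_B=True).
    """
    m, k_dim, n = len(a), len(a[0]), len(b)
    c = [[0.0] * n for _ in range(m)]            # accumulator, cleared once up front
    for k0 in range(0, k_dim, block_k):          # serial loop over K tiles
        k1 = min(k0 + block_k, k_dim)
        for i in range(m):
            for j in range(n):
                # Partial dot product of this K tile, added to the running sum.
                c[i][j] += sum(a[i][k] * b[j][k] for k in range(k0, k1))
    return c


a = [[1.0, 2.0, 3.0]]
b = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
tiled = tiled_matmul_bt(a, b, block_k=2)     # two K tiles: [0, 2) and [2, 3)
untiled = tiled_matmul_bt(a, b, block_k=3)   # a single K tile
```

The K loop has to stay serial because every tile updates the same accumulator; T.Pipelined keeps that order and only overlaps the loads of the next tile with the current multiplication.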