1. 使用 TileLang 语法编写一个 GPU kernel 函数

# Factory for a tiled element-wise add kernel. `out_idx=[-1]` is forwarded to
# tilelang.jit; presumably it marks the last tensor parameter (C) as the
# output -- the caller in this file invokes the compiled kernel with only
# (a, b) and uses the returned value as the result. TODO confirm against the
# tilelang.jit documentation.
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
    """Build the TileLang prim_func that computes ``C = A + B`` tile by tile.

    Args:
        M: Number of rows of A, B and C.
        N: Number of columns of A, B and C.
        block_M: Tile height handled by one kernel block.
        block_N: Tile width handled by one kernel block.
        in_dtype: Element dtype of the inputs A and B (and of their staging
            buffers).
        out_dtype: Element dtype of the output C (and of the accumulator and
            its staging buffer).
        threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``elem_add`` prim_func defined below (before the ``tilelang.jit``
        decorator processes it).
    """

    @T.prim_func
    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
        (M, N), out_dtype)):
        # Launch grid: ceildiv(N, block_N) blocks along the first axis (bx) and
        # ceildiv(M, block_M) along the second (by). bx therefore selects the
        # column tile and by the row tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            # Per-block staging tiles for the two inputs.
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            # The sum is first written to a fragment buffer, then staged
            # through C_shared before being copied out to C.
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
 
            # Load this block's (block_M, block_N) tile of A and B; the tile
            # origin is row by * block_M, column bx * block_N.
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)
            # Element-wise add over the whole tile, expressed as a parallel loop.
            for (local_y, local_x) in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
            # Write back: fragment -> C_shared -> the matching tile of C.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return elem_add

调用代码

def main():
    """Run the element-wise add kernel once and check it against the reference.

    Command-line flags:
        --m, --n: problem size (both default to 1024).
        --use_autotune: take the kernel from ``get_best_config`` instead of
            compiling the fixed default configuration.

    Unrecognised command-line arguments are ignored. The check is done with
    ``torch.testing.assert_close``, which raises if the kernel output and
    ``ref_program(a, b)`` differ by more than rtol/atol = 1e-2.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--m", type=int, default=1024)
    cli.add_argument("--n", type=int, default=1024)
    cli.add_argument("--use_autotune", action="store_true", default=False)
    opts, _unknown = cli.parse_known_args()
    M = opts.m
    N = opts.n

    # Random float32 inputs created directly on the CUDA device.
    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    if not opts.use_autotune:
        # Fixed default configuration: 64x64 tiles, 128 threads.
        kernel = elementwise_add(
            M,
            N,
            block_M=64,
            block_N=64,
            threads=128,
            in_dtype="float32",
            out_dtype="float32",
        )
    else:
        kernel = get_best_config(M, N).kernel

    out = kernel(a, b)
    expected = ref_program(a, b)
    torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)

2. kernel 的 TIR 表示

<class 'tvm.tir.function.PrimFunc'>
# from tvm.script import tir as T

# Printed (not yet lowered) TIR of elem_add for the default configuration:
# all shapes are already constants (1024 x 1024 tensors, 64 x 64 tiles), the
# grid is 16 x 16 blocks (1024 / 64 per axis) and threadIdx.x has extent 128.
# The tile-level T.copy calls and the parallel loop are still present as
# written in the source kernel.
@T.prim_func
def elem_add(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
    A = T.match_buffer(A_handle, (1024, 1024), strides=(1024, 1))
    B = T.match_buffer(B_handle, (1024, 1024), strides=(1024, 1))
    C = T.match_buffer(C_handle, (1024, 1024), strides=(1024, 1))
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 16)
    by = T.launch_thread("blockIdx.y", 16)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads(A[by * 64, bx * 64], B[by * 64, bx * 64], C[by * 64, bx * 64])
        T.writes()
        # alloc_shared became scope "shared.dyn"; alloc_fragment became scope
        # "local.fragment". All four buffers still have the full tile shape.
        A_shared = T.alloc_buffer((64, 64), scope="shared.dyn")
        B_shared = T.alloc_buffer((64, 64), scope="shared.dyn")
        C_local = T.alloc_buffer((64, 64), scope="local.fragment")
        C_shared = T.alloc_buffer((64, 64), scope="shared.dyn")
        # Global -> shared tile loads for A and B.
        T.copy(T.region(A[by * 64, bx * 64], 1, 64, 64), T.region(A_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0)
        T.copy(T.region(B[by * 64, bx * 64], 1, 64, 64), T.region(B_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0)
        # T.Parallel(block_M, block_N) is printed as two nested T.parallel loops.
        for local_y in T.parallel(64):
            for local_x in T.parallel(64):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
        # Fragment -> shared -> global write-back of the result tile.
        T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0)
        T.copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64), -1, T.bool(False), 0)

3. 对 TIR PrimFunc 进行 CUDA 后端编译

3.1 compile

tilelang/jit/__init__.py
def compile(func: PrimFunc = None, ...)

3.2 cached

tilelang/cache/__init__.py
def cached(func: PrimFunc = None, ...)

3.3 KernelCache.cached

tilelang/cache/kernel_cache.py
class KernelCache:
    def cached(self, func: PrimFunc = None, ...)

3.4 _load_kernel_from_disk

tilelang/cache/kernel_cache.py
class KernelCache:
    def _load_kernel_from_disk(...)

判断磁盘上是否存在缓存文件

3.5 初始化JITKernel

        kernel = JITKernel(
            func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            target=target,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs,
            compile_flags=compile_flags,
        )

3.6 JITKernel init函数

tilelang/jit/kernel.py
class JITKernel:
    def __init__(...)

调用 self._compile_and_create_adapter(func, out_idx) 函数，并将返回值保存为 adapter

3.7 _compile_and_create_adapter

调用tilelang.lower

            artifact = tilelang.lower(
                tilelang_func,
                target=target,
                target_host=target_host,
                enable_host_codegen=enable_host_codegen,
                enable_device_compile=enable_device_compile)

3.8 tilelang.lower

tilelang/engine/lower.py
def lower(func_or_mod: tir.PrimFunc | tvm.IRModule, ...)

mod = LowerAndLegalize(mod, target)

# from tvm.script import ir as I
# from tvm.script import tir as T

# IRModule as printed after the LowerAndLegalize step. Compared with the
# unlowered TIR:
#   * the function carries a CUDA target attribute (arch sm_120);
#   * the shared tiles are printed with shape (1, 1, 64, 64) and C_local is a
#     flat 32-element "local" buffer (64 * 64 elements / 128 threads = 32);
#   * the tile-level T.copy calls for A, B and C appear as T.tma_load /
#     T.tma_store calls guarded by `tx == 0`;
#   * the parallel loop appears as an 8-way unrolled loop over a 4-wide
#     vectorized loop, indexed by tx.
@I.ir_module
class Module:
    @T.prim_func
    def elem_add(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_120", "host": {"keys": ["cpu"], "kind": "c", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32})})
        A = T.match_buffer(A_handle, (1024, 1024), strides=(1024, 1))
        B = T.match_buffer(B_handle, (1024, 1024), strides=(1024, 1))
        C = T.match_buffer(C_handle, (1024, 1024), strides=(1024, 1))
        with T.block("root"):
            T.reads()
            T.writes()
            A_shared = T.Buffer((1, 1, 64, 64), scope="shared.dyn")
            B_shared = T.Buffer((1, 1, 64, 64), scope="shared.dyn")
            C_local = T.Buffer((32,), scope="local")
            C_shared = T.Buffer((1, 1, 64, 64), scope="shared.dyn")
            # Layout objects chosen for each buffer; their contents live in the
            # module metadata, which this printout omits.
            T.block_attr({"layout_map": {A_shared: metadata["tl.Layout"][0], B_shared: metadata["tl.Layout"][1], C_local: metadata["tl.Fragment"][0], C_shared: metadata["tl.Layout"][2]}})
            bx = T.launch_thread("blockIdx.x", 16)
            by = T.launch_thread("blockIdx.y", 16)
            tx = T.launch_thread("threadIdx.x", 128)
            ty = T.launch_thread("threadIdx.y", 1)
            tz = T.launch_thread("threadIdx.z", 1)
            with T.block("tilelang_root"):
                T.reads(A[by * 64, bx * 64], B[by * 64, bx * 64], C[by * 64, bx * 64])
                T.writes()
                T.block_attr({"layout_map": {A_shared: metadata["tl.Layout"][0], B_shared: metadata["tl.Layout"][1], C_local: metadata["tl.Fragment"][0], C_shared: metadata["tl.Layout"][2]}})
                A_shared = T.alloc_buffer((1, 1, 64, 64), data=A_shared.data, scope="shared.dyn")
                B_shared = T.alloc_buffer((1, 1, 64, 64), data=B_shared.data, scope="shared.dyn")
                C_local = T.alloc_buffer((32,), data=C_local.data, scope="local")
                C_shared = T.alloc_buffer((1, 1, 64, 64), data=C_shared.data, scope="shared.dyn")
                # Tile loads: only the thread with tx == 0 issues the tma_load.
                # The descriptor is built over the 1024 x 1024 tensor with a
                # 64 x 64 box, and the tile coordinates are passed as
                # (bx * 64, by * 64), i.e. column first.
                if tx == 0:
                    T.tma_load(T.create_tma_descriptor(7, 2, A.data, 1024, 1024, T.int64(4), T.int64(4096), 64, 64, 1, 1, 0, 0, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float32"), A_shared.data, 0, 4096, 2), bx * 64, by * 64, 0)
                if tx == 0:
                    T.tma_load(T.create_tma_descriptor(7, 2, B.data, 1024, 1024, T.int64(4), T.int64(4096), 64, 64, 1, 1, 0, 0, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float32"), B_shared.data, 0, 4096, 2), bx * 64, by * 64, 0)
                # Per-thread share of the add: 8 unrolled steps x 4 vector
                # lanes = 32 elements. In step i the thread touches tile row
                # i * 8 + tx // 16 and columns tx % 16 * 4 .. tx % 16 * 4 + 3.
                for i in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}):
                    for vec in T.vectorized(4):
                        C_local[i * 4 + vec] = A_shared[0, 0, i * 8 + tx // 16, tx % 16 * 4 + vec] + B_shared[0, 0, i * 8 + tx // 16, tx % 16 * 4 + vec]
                # Copy the 32 per-thread results into C_shared using the same
                # row/column mapping.
                for i in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}):
                    for vec in T.vectorized(4):
                        C_shared[0, 0, i * 8 + tx // 16, tx % 16 * 4 + vec] = C_local[i * 4 + vec]
                # Tile store of C_shared to C, again issued only by tx == 0.
                if tx == 0:
                    T.tma_store(T.create_tma_descriptor(7, 2, C.data, 1024, 1024, T.int64(4), T.int64(4096), 64, 64, 1, 1, 0, 0, 2, 0), T.tvm_access_ptr(T.type_annotation("float32"), C_shared.data, 0, 4096, 1), bx * 64, by * 64, 0, 0)

# Metadata omitted. Use show_meta=True in script() method to show it.

mod = OptimizeForTarget(mod, target)

# from tvm.script import ir as I
# from tvm.script import tir as T

# IRModule as printed after the OptimizeForTarget step. The single function of
# the previous dump has been split into two:
#   * elem_add        -- host-side entry point (target kind "c", calling_conv 1)
#                        that validates its packed arguments, creates the three
#                        tensor-map descriptors and launches the kernel;
#   * elem_add_kernel -- device function (target kind "cuda", calling_conv 2).
@I.ir_module
class Module:
    # Host entry point. It receives the packed argument array, checks every
    # DLTensor (ndim, dtype, shape, strides, byte_offset, device) against the
    # constants baked in at compile time, then launches elem_add_kernel.
    @T.prim_func
    def elem_add(args: T.handle, arg_type_ids: T.handle("int32", "global"), num_args: T.int32, out_ret_value: T.handle("void", "global"), out_ret_tcode: T.handle("int32", "global"), resource_handle: T.handle) -> T.int32:
        C_desc = T.handle("uint8x128", "grid_constant")
        C = T.handle("float32", "global")
        B_desc = T.handle("uint8x128", "grid_constant")
        B = T.handle("float32", "global")
        A_desc = T.handle("uint8x128", "grid_constant")
        A = T.handle("float32", "global")
        T.func_attr({"calling_conv": 1, "target": T.target({"keys": ["cpu"], "kind": "c", "tag": ""}), "thread_extent": {}, "tir.is_entry_func": True, "tma_descriptor_args": {C_desc: ["__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], B_desc: ["__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], A_desc: ["__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0]}})
        # Argument-count and argument-type-code checks.
        assert num_args == 3, "elem_add: num_args should be 3"
        assert not T.isnullptr(args), "elem_add: TVMValue* arg pointer was NULL"
        assert not T.isnullptr(arg_type_ids), "elem_add: int* type_codes was NULL"
        arg_type_ids_1 = T.decl_buffer((3,), "int32", data=arg_type_ids)
        A_handle_code: T.int32 = arg_type_ids_1[0]
        assert A_handle_code == 0 or A_handle_code == 4 or A_handle_code == 7 or A_handle_code >= 64, "elem_add: Expect arg[0] to be pointer"
        B_handle_code: T.int32 = arg_type_ids_1[1]
        assert B_handle_code == 0 or B_handle_code == 4 or B_handle_code == 7 or B_handle_code >= 64, "elem_add: Expect arg[1] to be pointer"
        C_handle_code: T.int32 = arg_type_ids_1[2]
        assert C_handle_code == 0 or C_handle_code == 4 or C_handle_code == 7 or C_handle_code >= 64, "elem_add: Expect arg[2] to be pointer"
        # Unpack the three tensor handles, then their shape/strides/data fields.
        A_handle: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
        B_handle: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
        C_handle: T.handle = T.tvm_struct_get(args, 2, 12, "handle")
        assert not T.isnullptr(A_handle), "elem_add.A_handle is expected to have non-NULL DLTensor* pointer"
        assert 2 == T.tvm_struct_get(A_handle, 0, 4, "int32"), "elem_add.A_handle.ndim is expected to equal 2"
        elem_add_A_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 2, "handle")
        elem_add_A_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_shape)
        elem_add_A_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 3, "handle")
        elem_add_A_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_strides)
        # The device id is taken from A; B and C are later asserted to match it.
        dev_id: T.int32 = T.tvm_struct_get(A_handle, 0, 9, "int32")
        with T.LetStmt(T.tvm_struct_get(A_handle, 0, 1, "handle"), var=A):
            T.attr(A, "storage_alignment", 64)
            assert not T.isnullptr(B_handle), "elem_add.B_handle is expected to have non-NULL DLTensor* pointer"
            assert 2 == T.tvm_struct_get(B_handle, 0, 4, "int32"), "elem_add.B_handle.ndim is expected to equal 2"
            elem_add_B_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 2, "handle")
            elem_add_B_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_shape)
            elem_add_B_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 3, "handle")
            elem_add_B_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_strides)
            with T.LetStmt(T.tvm_struct_get(B_handle, 0, 1, "handle"), var=B):
                T.attr(B, "storage_alignment", 64)
                assert not T.isnullptr(C_handle), "elem_add.C_handle is expected to have non-NULL DLTensor* pointer"
                assert 2 == T.tvm_struct_get(C_handle, 0, 4, "int32"), "elem_add.C_handle.ndim is expected to equal 2"
                elem_add_C_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 2, "handle")
                elem_add_C_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_shape)
                elem_add_C_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 3, "handle")
                elem_add_C_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_strides)
                with T.LetStmt(T.tvm_struct_get(C_handle, 0, 1, "handle"), var=C):
                    T.attr(C, "storage_alignment", 64)
                    T.attr("default", "device_id", dev_id)
                    T.attr("default", "device_type", 2)
                    # Per-tensor checks: float32 dtype, shape 1024 x 1024,
                    # strides (1024, 1) (or a NULL strides pointer), zero
                    # byte_offset, device_type 2 and a non-NULL data pointer.
                    assert T.tvm_struct_get(A_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(A_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(A_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.A_handle.dtype is expected to be float32"
                    assert T.Cast("int32", elem_add_A_handle_shape_1[0]) == 1024, "Argument elem_add.A_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[0])"
                    assert T.Cast("int32", elem_add_A_handle_shape_1[1]) == 1024, "Argument elem_add.A_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[1])"
                    assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast("int32", elem_add_A_handle_strides_1[1])) == 1, "Argument elem_add.A_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast(\"int32\", elem_add_A_handle_strides_1[1]))"
                    assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast("int32", elem_add_A_handle_shape_1[1]), T.Cast("int32", elem_add_A_handle_strides_1[0])) == 1024, "Argument elem_add.A_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast(\"int32\", elem_add_A_handle_shape[1]), T.Cast(\"int32\", elem_add_A_handle_strides_1[0]))"
                    assert T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, "uint64"), "Argument elem_add.A_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, \"uint64\")"
                    assert T.tvm_struct_get(A_handle, 0, 10, "int32") == 2, "Argument elem_add.A_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(A_handle, 0, 10, \"int32\")"
                    assert not T.isnullptr(A), "elem_add.A_handle is expected to have non-NULL data pointer"
                    assert T.tvm_struct_get(B_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(B_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(B_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.B_handle.dtype is expected to be float32"
                    assert T.Cast("int32", elem_add_B_handle_shape_1[0]) == 1024, "Argument elem_add.B_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[0])"
                    assert T.Cast("int32", elem_add_B_handle_shape_1[1]) == 1024, "Argument elem_add.B_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[1])"
                    assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast("int32", elem_add_B_handle_strides_1[1])) == 1, "Argument elem_add.B_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast(\"int32\", elem_add_B_handle_strides_1[1]))"
                    assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast("int32", elem_add_B_handle_shape_1[1]), T.Cast("int32", elem_add_B_handle_strides_1[0])) == 1024, "Argument elem_add.B_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast(\"int32\", elem_add_B_handle_shape[1]), T.Cast(\"int32\", elem_add_B_handle_strides_1[0]))"
                    assert T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, "uint64"), "Argument elem_add.B_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, \"uint64\")"
                    assert T.tvm_struct_get(B_handle, 0, 10, "int32") == 2, "Argument elem_add.B_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(B_handle, 0, 10, \"int32\")"
                    assert dev_id == T.tvm_struct_get(B_handle, 0, 9, "int32"), "Argument elem_add.B_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(B_handle, 0, 9, \"int32\")"
                    assert not T.isnullptr(B), "elem_add.B_handle is expected to have non-NULL data pointer"
                    assert T.tvm_struct_get(C_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(C_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(C_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.C_handle.dtype is expected to be float32"
                    assert T.Cast("int32", elem_add_C_handle_shape_1[0]) == 1024, "Argument elem_add.C_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[0])"
                    assert T.Cast("int32", elem_add_C_handle_shape_1[1]) == 1024, "Argument elem_add.C_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[1])"
                    assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast("int32", elem_add_C_handle_strides_1[1])) == 1, "Argument elem_add.C_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast(\"int32\", elem_add_C_handle_strides_1[1]))"
                    assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast("int32", elem_add_C_handle_shape_1[1]), T.Cast("int32", elem_add_C_handle_strides_1[0])) == 1024, "Argument elem_add.C_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast(\"int32\", elem_add_C_handle_shape[1]), T.Cast(\"int32\", elem_add_C_handle_strides_1[0]))"
                    assert T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, "uint64"), "Argument elem_add.C_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, \"uint64\")"
                    assert T.tvm_struct_get(C_handle, 0, 10, "int32") == 2, "Argument elem_add.C_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(C_handle, 0, 10, \"int32\")"
                    assert dev_id == T.tvm_struct_get(C_handle, 0, 9, "int32"), "Argument elem_add.C_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(C_handle, 0, 9, \"int32\")"
                    assert not T.isnullptr(C), "elem_add.C_handle is expected to have non-NULL data pointer"
                    A_1 = T.decl_buffer((1024, 1024), data=A, strides=(1024, 1))
                    B_1 = T.decl_buffer((1024, 1024), data=B, strides=(1024, 1))
                    C_1 = T.decl_buffer((1024, 1024), data=C, strides=(1024, 1))
                    assert T.FloorMod(1024, 8) == 0, "A: Vectorize dimension in buffer must be divisible by 8"
                    assert T.FloorMod(1024, 8) == 0, "B: Vectorize dimension in buffer must be divisible by 8"
                    assert T.FloorMod(1024, 8) == 0, "C: Vectorize dimension in buffer must be divisible by 8"
                    T.call_packed("__tvm_set_device", 2, dev_id)
                    with T.attr(0, "compute_scope", "elem_add_compute_"):
                        # One descriptor per tensor is created on the host and
                        # passed to the kernel; the arguments repeat the values
                        # listed under "tma_descriptor_args" in func_attr above.
                        with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=A_desc):
                            T.call_packed("__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
                            with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=B_desc):
                                T.call_packed("__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
                                with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=C_desc):
                                    T.call_packed("__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
                                    # Kernel launch. The trailing values
                                    # (16, 16, 256, 1, 1, 32768) line up with the
                                    # kernel's "tir.kernel_launch_params":
                                    # blockIdx.x, blockIdx.y, threadIdx.x/y/z and
                                    # the dynamic shared memory size.
                                    T.call_packed("elem_add_kernel", A_desc, B_desc, C_desc, 16, 16, 256, 1, 1, 32768)
                    return 0

    # Device kernel. Note the differences from the previous dump: threadIdx.x
    # now has extent 256 (it was 128), the kernel takes the three descriptors
    # instead of tensor handles, and all shared tiles are carved out of one
    # 32768-byte dynamic shared-memory allocation.
    @T.prim_func
    def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
        # A_shared, B_shared and C_shared are all views of buf_dyn_shmem; they
        # are told apart below by the element offsets 0 (A and C) and 4096 (B).
        buf_dyn_shmem = T.handle("uint8", "shared.dyn")
        C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        C_local = T.handle("float32", "local")
        C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
        bx = T.launch_thread("blockIdx.x", 16)
        buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
        C_local = T.allocate([32], "float32", "local")
        by = T.launch_thread("blockIdx.y", 16)
        tx = T.launch_thread("threadIdx.x", 256)
        # Setup: one barrier is created; an elected thread prefetches the three
        # descriptors and initialises mbarrier 0 with a thread count of 128.
        T.create_barriers(1)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
        T.ptx_fence_barrier_init()
        T.tvm_storage_sync("shared")
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        # The 256 threads are split into two groups of 128 (see the attribute
        # key "kWarpSpecializationScope"): tx >= 128 loads, tx < 128 computes.
        T.attr([128, 128], "kWarpSpecializationScope", 0)
        if 128 <= tx:
            # Loading group (tx in 128..255): an elected thread issues the
            # loads of the A tile (element offset 0) and the B tile (element
            # offset 4096); 16384 = 64 * 64 * 4 bytes per tile. Every thread
            # of this group then arrives on mbarrier 0.
            T.set_max_nreg(24, 0)
            if T.tl_shuffle_elect(128):
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                T.tma_load(A_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 2), bx * 64, by * 64, 0)
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                T.tma_load(B_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 4096, 4096, 2), bx * 64, by * 64, 0)
            T.ptx_arrive_barrier(T.get_mbarrier(0))
        else:
            # Compute group (tx in 0..127): wait on mbarrier 0, then each
            # thread adds 8 x 4 = 32 elements using 4-wide vector accesses at
            # flat index i * 512 + tx * 4.
            T.set_max_nreg(240, 1)
            T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
            for i in T.unroll(8):
                C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
            # C_shared is written at element offset 0, the same offset A_shared
            # was read from -- presumably the reason for the sync between the
            # read loop and the write loop.
            T.tvm_storage_sync("shared.dyn", 3, 128)
            for i in T.unroll(8):
                C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
            T.fence_proxy_async()
            T.tvm_storage_sync("shared.dyn", 3, 128)
            # An elected thread stores the result tile from shared memory
            # (offset 0) to C and waits for the store.
            if T.tl_shuffle_elect(128):
                T.tma_store(C_desc, T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 1), bx * 64, by * 64, 0, 0)
                T.tma_store_arrive()
                T.tma_store_wait()

host_mod = tir.transform.Filter(_is_host_call)(mod)

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def elem_add(args: T.handle, arg_type_ids: T.handle("int32", "global"), num_args: T.int32, out_ret_value: T.handle("void", "global"), out_ret_tcode: T.handle("int32", "global"), resource_handle: T.handle) -> T.int32:
        C_desc = T.handle("uint8x128", "grid_constant")
        C = T.handle("float32", "global")
        B_desc = T.handle("uint8x128", "grid_constant")
        B = T.handle("float32", "global")
        A_desc = T.handle("uint8x128", "grid_constant")
        A = T.handle("float32", "global")
        T.func_attr({"calling_conv": 1, "target": T.target({"keys": ["cpu"], "kind": "c", "tag": ""}), "thread_extent": {}, "tir.is_entry_func": True, "tma_descriptor_args": {C_desc: ["__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], B_desc: ["__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], A_desc: ["__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0]}})
        assert num_args == 3, "elem_add: num_args should be 3"
        assert not T.isnullptr(args), "elem_add: TVMValue* arg pointer was NULL"
        assert not T.isnullptr(arg_type_ids), "elem_add: int* type_codes was NULL"
        arg_type_ids_1 = T.decl_buffer((3,), "int32", data=arg_type_ids)
        A_handle_code: T.int32 = arg_type_ids_1[0]
        assert A_handle_code == 0 or A_handle_code == 4 or A_handle_code == 7 or A_handle_code >= 64, "elem_add: Expect arg[0] to be pointer"
        B_handle_code: T.int32 = arg_type_ids_1[1]
        assert B_handle_code == 0 or B_handle_code == 4 or B_handle_code == 7 or B_handle_code >= 64, "elem_add: Expect arg[1] to be pointer"
        C_handle_code: T.int32 = arg_type_ids_1[2]
        assert C_handle_code == 0 or C_handle_code == 4 or C_handle_code == 7 or C_handle_code >= 64, "elem_add: Expect arg[2] to be pointer"
        A_handle: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
        B_handle: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
        C_handle: T.handle = T.tvm_struct_get(args, 2, 12, "handle")
        assert not T.isnullptr(A_handle), "elem_add.A_handle is expected to have non-NULL DLTensor* pointer"
        assert 2 == T.tvm_struct_get(A_handle, 0, 4, "int32"), "elem_add.A_handle.ndim is expected to equal 2"
        elem_add_A_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 2, "handle")
        elem_add_A_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_shape)
        elem_add_A_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 3, "handle")
        elem_add_A_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_strides)
        dev_id: T.int32 = T.tvm_struct_get(A_handle, 0, 9, "int32")
        with T.LetStmt(T.tvm_struct_get(A_handle, 0, 1, "handle"), var=A):
            T.attr(A, "storage_alignment", 64)
            assert not T.isnullptr(B_handle), "elem_add.B_handle is expected to have non-NULL DLTensor* pointer"
            assert 2 == T.tvm_struct_get(B_handle, 0, 4, "int32"), "elem_add.B_handle.ndim is expected to equal 2"
            elem_add_B_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 2, "handle")
            elem_add_B_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_shape)
            elem_add_B_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 3, "handle")
            elem_add_B_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_strides)
            with T.LetStmt(T.tvm_struct_get(B_handle, 0, 1, "handle"), var=B):
                T.attr(B, "storage_alignment", 64)
                assert not T.isnullptr(C_handle), "elem_add.C_handle is expected to have non-NULL DLTensor* pointer"
                assert 2 == T.tvm_struct_get(C_handle, 0, 4, "int32"), "elem_add.C_handle.ndim is expected to equal 2"
                elem_add_C_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 2, "handle")
                elem_add_C_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_shape)
                elem_add_C_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 3, "handle")
                elem_add_C_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_strides)
                with T.LetStmt(T.tvm_struct_get(C_handle, 0, 1, "handle"), var=C):
                    T.attr(C, "storage_alignment", 64)
                    T.attr("default", "device_id", dev_id)
                    T.attr("default", "device_type", 2)
                    assert T.tvm_struct_get(A_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(A_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(A_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.A_handle.dtype is expected to be float32"
                    assert T.Cast("int32", elem_add_A_handle_shape_1[0]) == 1024, "Argument elem_add.A_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[0])"
                    assert T.Cast("int32", elem_add_A_handle_shape_1[1]) == 1024, "Argument elem_add.A_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[1])"
                    assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast("int32", elem_add_A_handle_strides_1[1])) == 1, "Argument elem_add.A_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast(\"int32\", elem_add_A_handle_strides_1[1]))"
                    assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast("int32", elem_add_A_handle_shape_1[1]), T.Cast("int32", elem_add_A_handle_strides_1[0])) == 1024, "Argument elem_add.A_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast(\"int32\", elem_add_A_handle_shape[1]), T.Cast(\"int32\", elem_add_A_handle_strides_1[0]))"
                    assert T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, "uint64"), "Argument elem_add.A_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, \"uint64\")"
                    assert T.tvm_struct_get(A_handle, 0, 10, "int32") == 2, "Argument elem_add.A_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(A_handle, 0, 10, \"int32\")"
                    assert not T.isnullptr(A), "elem_add.A_handle is expected to have non-NULL data pointer"
                    assert T.tvm_struct_get(B_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(B_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(B_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.B_handle.dtype is expected to be float32"
                    assert T.Cast("int32", elem_add_B_handle_shape_1[0]) == 1024, "Argument elem_add.B_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[0])"
                    assert T.Cast("int32", elem_add_B_handle_shape_1[1]) == 1024, "Argument elem_add.B_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[1])"
                    assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast("int32", elem_add_B_handle_strides_1[1])) == 1, "Argument elem_add.B_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast(\"int32\", elem_add_B_handle_strides_1[1]))"
                    assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast("int32", elem_add_B_handle_shape_1[1]), T.Cast("int32", elem_add_B_handle_strides_1[0])) == 1024, "Argument elem_add.B_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast(\"int32\", elem_add_B_handle_shape[1]), T.Cast(\"int32\", elem_add_B_handle_strides_1[0]))"
                    assert T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, "uint64"), "Argument elem_add.B_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, \"uint64\")"
                    assert T.tvm_struct_get(B_handle, 0, 10, "int32") == 2, "Argument elem_add.B_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(B_handle, 0, 10, \"int32\")"
                    assert dev_id == T.tvm_struct_get(B_handle, 0, 9, "int32"), "Argument elem_add.B_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(B_handle, 0, 9, \"int32\")"
                    assert not T.isnullptr(B), "elem_add.B_handle is expected to have non-NULL data pointer"
                    assert T.tvm_struct_get(C_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(C_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(C_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.C_handle.dtype is expected to be float32"
                    assert T.Cast("int32", elem_add_C_handle_shape_1[0]) == 1024, "Argument elem_add.C_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[0])"
                    assert T.Cast("int32", elem_add_C_handle_shape_1[1]) == 1024, "Argument elem_add.C_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[1])"
                    assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast("int32", elem_add_C_handle_strides_1[1])) == 1, "Argument elem_add.C_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast(\"int32\", elem_add_C_handle_strides_1[1]))"
                    assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast("int32", elem_add_C_handle_shape_1[1]), T.Cast("int32", elem_add_C_handle_strides_1[0])) == 1024, "Argument elem_add.C_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast(\"int32\", elem_add_C_handle_shape[1]), T.Cast(\"int32\", elem_add_C_handle_strides_1[0]))"
                    assert T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, "uint64"), "Argument elem_add.C_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, \"uint64\")"
                    assert T.tvm_struct_get(C_handle, 0, 10, "int32") == 2, "Argument elem_add.C_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(C_handle, 0, 10, \"int32\")"
                    assert dev_id == T.tvm_struct_get(C_handle, 0, 9, "int32"), "Argument elem_add.C_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(C_handle, 0, 9, \"int32\")"
                    assert not T.isnullptr(C), "elem_add.C_handle is expected to have non-NULL data pointer"
                    A_1 = T.decl_buffer((1024, 1024), data=A, strides=(1024, 1))
                    B_1 = T.decl_buffer((1024, 1024), data=B, strides=(1024, 1))
                    C_1 = T.decl_buffer((1024, 1024), data=C, strides=(1024, 1))
                    assert T.FloorMod(1024, 8) == 0, "A: Vectorize dimension in buffer must be divisible by 8"
                    assert T.FloorMod(1024, 8) == 0, "B: Vectorize dimension in buffer must be divisible by 8"
                    assert T.FloorMod(1024, 8) == 0, "C: Vectorize dimension in buffer must be divisible by 8"
                    T.call_packed("__tvm_set_device", 2, dev_id)
                    with T.attr(0, "compute_scope", "elem_add_compute_"):
                        with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=A_desc):
                            T.call_packed("__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
                            with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=B_desc):
                                T.call_packed("__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
                                with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=C_desc):
                                    T.call_packed("__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
                                    T.call_packed("elem_add_kernel", A_desc, B_desc, C_desc, 16, 16, 256, 1, 1, 32768)
                    return 0

# Build a module from `mod` that keeps only the functions accepted by the
# _is_device_call predicate (defined outside this excerpt; presumably the GPU
# kernels). The resulting module is the one printed below.
device_mod = tir.transform.Filter(_is_device_call)(mod)

# from tvm.script import ir as I
# from tvm.script import tir as T

# Device-side module produced by the Filter step above; it holds one kernel,
# elem_add_kernel. Each block runs 256 threads although the tilelang source asked
# for threads=128: the kWarpSpecializationScope [128, 128] split below sends
# tx >= 128 down a TMA-load branch and tx < 128 down the compute/store branch.
@I.ir_module
class Module:
    @T.prim_func
    def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
        # A_shared, B_shared and C_shared are all declared over the same dynamic
        # shared buffer. In the body, A_shared and C_shared are indexed from offset 0
        # and B_shared from offset 4096, so the C tile is written over the A tile.
        buf_dyn_shmem = T.handle("uint8", "shared.dyn")
        C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        # Per-thread result storage: 32 float32 values in "local" scope.
        C_local = T.handle("float32", "local")
        C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
        bx = T.launch_thread("blockIdx.x", 16)
        # 32768 bytes = 2 x 4096 float32, i.e. room for the A tile and the B tile.
        buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
        C_local = T.allocate([32], "float32", "local")
        by = T.launch_thread("blockIdx.y", 16)
        tx = T.launch_thread("threadIdx.x", 256)
        # One mbarrier. Inside the tl_shuffle_elect(0) guard (presumably a single
        # elected thread) the three TMA descriptors are prefetched and the barrier's
        # arrival count is set to 128.
        T.create_barriers(1)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
        # Fence plus block-wide sync before any thread uses the barrier.
        T.ptx_fence_barrier_init()
        T.tvm_storage_sync("shared")
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        T.attr([128, 128], "kWarpSpecializationScope", 0)
        if 128 <= tx:
            # Loader branch (tx in 128..255). set_max_nreg(24, 0) shows up as
            # tl::warpgroup_reg_dealloc<24>() in the CUDA listing further below.
            T.set_max_nreg(24, 0)
            if T.tl_shuffle_elect(128):
                # Each load first announces 16384 expected bytes (4096 float32) on
                # barrier 0, then issues the TMA copy of the tile at (bx * 64, by * 64).
                # A lands at float offset 0, B at float offset 4096.
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                T.tma_load(A_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 2), bx * 64, by * 64, 0)
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                T.tma_load(B_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 4096, 4096, 2), bx * 64, by * 64, 0)
            # All 128 loader threads arrive, matching the arrival count set above.
            T.ptx_arrive_barrier(T.get_mbarrier(0))
        else:
            # Compute branch (tx in 0..127). set_max_nreg(240, 1) shows up as
            # tl::warpgroup_reg_alloc<240>() in the CUDA listing further below.
            T.set_max_nreg(240, 1)
            # Block until barrier 0 completes its first phase (parity 0).
            T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
            # 8 iterations x 4 lanes = 32 elements per thread; 128 threads cover 4096.
            for i in T.unroll(8):
                C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
            # Sync restricted to 128 threads (emitted as tl::__sync_thread_partial<3, 128>
            # in the CUDA listing); separates the reads above from the overwrite below.
            T.tvm_storage_sync("shared.dyn", 3, 128)
            for i in T.unroll(8):
                C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
            # NOTE(review): presumably makes the shared-memory writes visible to the
            # TMA engine before the store is issued - confirm against the PTX ISA.
            T.fence_proxy_async()
            T.tvm_storage_sync("shared.dyn", 3, 128)
            if T.tl_shuffle_elect(128):
                # Store the 4096-float result starting at float offset 0 of shared
                # memory to C at (bx * 64, by * 64), then wait for the store.
                T.tma_store(C_desc, T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 1), bx * 64, by * 64, 0, 0)
                T.tma_store_arrive()
                T.tma_store_wait()

接下来调用 device_codegen_without_compile：

# Pick the code generator: device_codegen when enable_device_compile is truthy,
# otherwise device_codegen_without_compile. Both are called with (device_mod, target).
codegen_mod = device_codegen(
    device_mod, target) if enable_device_compile else device_codegen_without_compile(
        device_mod, target)

device_codegen_without_compile

tilelang/engine/lower.py
def device_codegen_without_compile(device_mod

LowerDeviceStorageAccessInfo

# from tvm.script import ir as I
# from tvm.script import tir as T

# elem_add_kernel as listed under the "LowerDeviceStorageAccessInfo" heading.
# Compared with the previous dump, the T.tvm_access_ptr(...) arguments of
# tma_load / tma_store are now T.address_of(...) on ad-hoc T.Buffer views of
# buf_dyn_shmem; every other statement is unchanged.
@I.ir_module
class Module:
    @T.prim_func
    def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
        # All three shared tiles are views of the same dynamic shared buffer.
        buf_dyn_shmem = T.handle("uint8", "shared.dyn")
        C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        C_local = T.handle("float32", "local")
        C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
        bx = T.launch_thread("blockIdx.x", 16)
        buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
        C_local = T.allocate([32], "float32", "local")
        by = T.launch_thread("blockIdx.y", 16)
        tx = T.launch_thread("threadIdx.x", 256)
        # Barrier setup: one mbarrier with an arrival count of 128.
        T.create_barriers(1)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
        T.ptx_fence_barrier_init()
        T.tvm_storage_sync("shared")
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        T.attr([128, 128], "kWarpSpecializationScope", 0)
        if 128 <= tx:
            # Loader branch: threads 128..255.
            T.set_max_nreg(24, 0)
            if T.tl_shuffle_elect(128):
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                # Destination of the A tile: element 0 of the shared buffer.
                buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_load(A_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0)
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                # Destination of the B tile: element 4096; the view is sized 4097
                # so that index 4096 is in range.
                buf_dyn_shmem_2 = T.Buffer((4097,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_load(B_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_2[4096]), bx * 64, by * 64, 0)
            T.ptx_arrive_barrier(T.get_mbarrier(0))
        else:
            # Compute branch: threads 0..127 wait for the loads, add, then store.
            T.set_max_nreg(240, 1)
            T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
            for i in T.unroll(8):
                C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
            T.tvm_storage_sync("shared.dyn", 3, 128)
            for i in T.unroll(8):
                C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
            T.fence_proxy_async()
            T.tvm_storage_sync("shared.dyn", 3, 128)
            if T.tl_shuffle_elect(128):
                # Source of the store: element 0 of the shared buffer (the C tile).
                buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_store(C_desc, T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0, 0)
                T.tma_store_arrive()
                T.tma_store_wait()

LowerIntrin

# from tvm.script import ir as I
# from tvm.script import tir as T

# elem_add_kernel as listed under the "LowerIntrin" heading. The text of this dump
# is identical to the one listed under "LowerDeviceStorageAccessInfo": no statement
# of this kernel differs between the two listings.
@I.ir_module
class Module:
    @T.prim_func
    def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
        # All three shared tiles are views of the same dynamic shared buffer.
        buf_dyn_shmem = T.handle("uint8", "shared.dyn")
        C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        C_local = T.handle("float32", "local")
        C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
        bx = T.launch_thread("blockIdx.x", 16)
        buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
        C_local = T.allocate([32], "float32", "local")
        by = T.launch_thread("blockIdx.y", 16)
        tx = T.launch_thread("threadIdx.x", 256)
        # Barrier setup: one mbarrier with an arrival count of 128.
        T.create_barriers(1)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
        T.ptx_fence_barrier_init()
        T.tvm_storage_sync("shared")
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        T.attr([128, 128], "kWarpSpecializationScope", 0)
        if 128 <= tx:
            # Loader branch: threads 128..255 issue the two TMA loads and arrive.
            T.set_max_nreg(24, 0)
            if T.tl_shuffle_elect(128):
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_load(A_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0)
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                buf_dyn_shmem_2 = T.Buffer((4097,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_load(B_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_2[4096]), bx * 64, by * 64, 0)
            T.ptx_arrive_barrier(T.get_mbarrier(0))
        else:
            # Compute branch: threads 0..127 wait for the loads, add, then store.
            T.set_max_nreg(240, 1)
            T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
            for i in T.unroll(8):
                C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
            T.tvm_storage_sync("shared.dyn", 3, 128)
            for i in T.unroll(8):
                C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
            T.fence_proxy_async()
            T.tvm_storage_sync("shared.dyn", 3, 128)
            if T.tl_shuffle_elect(128):
                buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_store(C_desc, T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0, 0)
                T.tma_store_arrive()
                T.tma_store_wait()

Simplify

# from tvm.script import ir as I
# from tvm.script import tir as T

# elem_add_kernel as listed under the "Simplify" heading. Compared with the
# previous dump, the vector indices written as T.Ramp(...) + T.Broadcast(...) are
# now plain 4-element slices; the B_shared slice carries the +4096 offset directly.
@I.ir_module
class Module:
    @T.prim_func
    def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
        T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
        # All three shared tiles are views of the same dynamic shared buffer.
        buf_dyn_shmem = T.handle("uint8", "shared.dyn")
        C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
        C_local = T.handle("float32", "local")
        C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
        bx = T.launch_thread("blockIdx.x", 16)
        buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
        C_local = T.allocate([32], "float32", "local")
        by = T.launch_thread("blockIdx.y", 16)
        tx = T.launch_thread("threadIdx.x", 256)
        # Barrier setup: one mbarrier with an arrival count of 128.
        T.create_barriers(1)
        if T.tl_shuffle_elect(0):
            T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
            T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
            T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
        T.ptx_fence_barrier_init()
        T.tvm_storage_sync("shared")
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        T.attr([128, 128], "kWarpSpecializationScope", 0)
        if 128 <= tx:
            # Loader branch: threads 128..255 issue the two TMA loads and arrive.
            T.set_max_nreg(24, 0)
            if T.tl_shuffle_elect(128):
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_load(A_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0)
                T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
                T.fence_proxy_async()
                buf_dyn_shmem_2 = T.Buffer((4097,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_load(B_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_2[4096]), bx * 64, by * 64, 0)
            T.ptx_arrive_barrier(T.get_mbarrier(0))
        else:
            # Compute branch: threads 0..127 wait for the loads, add, then store.
            T.set_max_nreg(240, 1)
            T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
            for i in T.unroll(8):
                # Thread tx handles floats [i * 512 + tx * 4, i * 512 + tx * 4 + 4) of each tile.
                C_local_1[i * 4:i * 4 + 4] = A_shared[i * 512 + tx * 4:i * 512 + tx * 4 + 4] + B_shared[i * 512 + tx * 4 + 4096:i * 512 + tx * 4 + 4096 + 4]
            T.tvm_storage_sync("shared.dyn", 3, 128)
            for i in T.unroll(8):
                # Written to the same offsets the A values were read from.
                C_shared[i * 512 + tx * 4:i * 512 + tx * 4 + 4] = C_local_1[i * 4:i * 4 + 4]
            T.fence_proxy_async()
            T.tvm_storage_sync("shared.dyn", 3, 128)
            if T.tl_shuffle_elect(128):
                buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
                T.tma_store(C_desc, T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0, 0)
                T.tma_store_arrive()
                T.tma_store_wait()

# Look up the global function registered as
# "target.build.tilelang_cuda_without_compile" and call it with (device_mod, target).
# NOTE(review): the CUDA source listed next is presumably what this call produces;
# confirm against tilelang/engine/lower.py.
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)

#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

// elem_add_kernel: forward declaration followed by the definition. The three TMA
// descriptors are passed by value as __grid_constant__ parameters.
extern "C" __global__ void elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc);
// Launch bounds declare 256 threads per block. Threads 128..255 take the TMA-load
// branch; threads 0..127 add the two tiles and store the result.
extern "C" __global__ void __launch_bounds__(256, 1) elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc) {
  // Dynamic shared memory viewed as floats: [0, 4096) receives the A tile and is
  // later overwritten with the C tile; [4096, 8192) receives the B tile.
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  // Per-thread results: 8 iterations x one float4.
  float C_local[32];
  // One 64-bit mbarrier held in static shared memory.
  __shared__ uint64_t mbarrier_mem[1];
  auto mbarrier = reinterpret_cast<Barrier*>(mbarrier_mem);
  // Inside the elect guard (presumably one thread): prefetch the descriptors and
  // initialise the barrier with an arrival count of 128.
  if (tl::tl_shuffle_elect<0>()) {
    tl::prefetch_tma_descriptor(A_desc);
    tl::prefetch_tma_descriptor(B_desc);
    tl::prefetch_tma_descriptor(C_desc);
    mbarrier[0].init(128);
  }
  // Fence plus block-wide sync before any thread uses the barrier.
  tl::fence_barrier_init();
  __syncthreads();
  if (128 <= ((int)threadIdx.x)) {
    // Loader branch.
    tl::warpgroup_reg_dealloc<24>();
    if (tl::tl_shuffle_elect<128>()) {
      // Each load announces 16384 expected bytes (4096 floats) on the barrier and
      // then issues the TMA copy of the tile at (blockIdx.x * 64, blockIdx.y * 64).
      mbarrier[0].expect_transaction(16384);
      tl::fence_proxy_async();
      tl::tma_load(A_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
      mbarrier[0].expect_transaction(16384);
      tl::fence_proxy_async();
      tl::tma_load(B_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[4096])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
    }
    // Every thread of this branch arrives (128 arrivals, matching init(128)).
    mbarrier[0].arrive();
  } else {
    // Compute branch.
    tl::warpgroup_reg_alloc<240>();
    // Wait for phase 0 of the barrier, i.e. for both tiles to have landed.
    mbarrier[0].wait(0);
    #pragma unroll
    for (int i = 0; i < 8; ++i) {
      // float4 add: A from float offset i*512 + tid*4, B from the same offset + 4096.
      float4 __1;
        float4 v_ = *(float4*)(((float*)buf_dyn_shmem) + ((i * 512) + (((int)threadIdx.x) * 4)));
        float4 v__1 = *(float4*)(((float*)buf_dyn_shmem) + (((i * 512) + (((int)threadIdx.x) * 4)) + 4096));
        __1.x = (v_.x+v__1.x);
        __1.y = (v_.y+v__1.y);
        __1.z = (v_.z+v__1.z);
        __1.w = (v_.w+v__1.w);
      *(float4*)(C_local + (i * 4)) = __1;
    }
    // Sync among 128 threads, separating the reads above from the overwrite below.
    tl::__sync_thread_partial<3, 128>();
    #pragma unroll
    for (int i_1 = 0; i_1 < 8; ++i_1) {
      // Results go back to the offsets the A values were read from.
      *(float4*)(((float*)buf_dyn_shmem) + ((i_1 * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(C_local + (i_1 * 4));
    }
    // NOTE(review): presumably makes the shared-memory writes visible to the TMA
    // engine before the store - confirm against the PTX ISA.
    tl::fence_proxy_async();
    tl::__sync_thread_partial<3, 128>();
    if (tl::tl_shuffle_elect<128>()) {
      // TMA store of the C tile from float offset 0, then wait for completion.
      tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
      tl::tma_store_arrive();
      tl::tma_store_wait<0>();
    }
  }
}

CythonKernelAdapter

tilelang.lower 调用完成之后，会继续创建 CythonKernelAdapter：

            adapter = CythonKernelAdapter(
                params=artifact.params,
                result_idx=out_idx,
                target=target,
                func_or_mod=tilelang_func,
                host_mod=artifact.host_mod,
                device_mod=artifact.device_mod,
                kernel_global_source=artifact.kernel_source,
                verbose=verbose,
                pass_configs=pass_configs,
                compile_flags=compile_flags,
            )

self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))

#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

// elem_add_kernel as it appears inside the wrapped source; the host-side entry
// points (init / call) follow it in the same listing.
extern "C" __global__ void elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc);
// 256 threads per block: threads 128..255 load via TMA, threads 0..127 compute
// and store.
extern "C" __global__ void __launch_bounds__(256, 1) elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc) {
  // Floats [0, 4096): A tile, later the C tile. Floats [4096, 8192): B tile.
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float C_local[32];
  __shared__ uint64_t mbarrier_mem[1];
  auto mbarrier = reinterpret_cast<Barrier*>(mbarrier_mem);
  // Descriptor prefetch and barrier initialisation (arrival count 128).
  if (tl::tl_shuffle_elect<0>()) {
    tl::prefetch_tma_descriptor(A_desc);
    tl::prefetch_tma_descriptor(B_desc);
    tl::prefetch_tma_descriptor(C_desc);
    mbarrier[0].init(128);
  }
  tl::fence_barrier_init();
  __syncthreads();
  if (128 <= ((int)threadIdx.x)) {
    // Loader branch: two TMA loads of 16384 bytes each, tracked by barrier 0.
    tl::warpgroup_reg_dealloc<24>();
    if (tl::tl_shuffle_elect<128>()) {
      mbarrier[0].expect_transaction(16384);
      tl::fence_proxy_async();
      tl::tma_load(A_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
      mbarrier[0].expect_transaction(16384);
      tl::fence_proxy_async();
      tl::tma_load(B_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[4096])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
    }
    mbarrier[0].arrive();
  } else {
    // Compute branch: wait for the tiles, add 32 floats per thread, write back.
    tl::warpgroup_reg_alloc<240>();
    mbarrier[0].wait(0);
    #pragma unroll
    for (int i = 0; i < 8; ++i) {
      float4 __1;
        float4 v_ = *(float4*)(((float*)buf_dyn_shmem) + ((i * 512) + (((int)threadIdx.x) * 4)));
        float4 v__1 = *(float4*)(((float*)buf_dyn_shmem) + (((i * 512) + (((int)threadIdx.x) * 4)) + 4096));
        __1.x = (v_.x+v__1.x);
        __1.y = (v_.y+v__1.y);
        __1.z = (v_.z+v__1.z);
        __1.w = (v_.w+v__1.w);
      *(float4*)(C_local + (i * 4)) = __1;
    }
    // Sync among 128 threads before the A region is overwritten with the results.
    tl::__sync_thread_partial<3, 128>();
    #pragma unroll
    for (int i_1 = 0; i_1 < 8; ++i_1) {
      *(float4*)(((float*)buf_dyn_shmem) + ((i_1 * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(C_local + (i_1 * 4));
    }
    tl::fence_proxy_async();
    tl::__sync_thread_partial<3, 128>();
    if (tl::tl_shuffle_elect<128>()) {
      // TMA store of the C tile from float offset 0 of shared memory.
      tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
      tl::tma_store_arrive();
      tl::tma_store_wait<0>();
    }
  }
}


// Single static buffer holding the most recent error message written by init()
// or call(); a later failure overwrites the earlier text.
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

// Returns a pointer to the static error buffer (an empty string after a
// successful init()).
extern "C" const char* get_last_error() {
    return error_buf;
}

// One-time setup: clears the error buffer and sets the kernel's maximum dynamic
// shared memory size to 32768 bytes, the amount the launch in call() requests.
// Returns 0 on success; on failure writes a message to error_buf and returns -1.
extern "C" int init() {
    error_buf[0] = '\0';
    
    cudaError_t result_elem_add_kernel = cudaFuncSetAttribute(elem_add_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 32768);
    if (result_elem_add_kernel != cudaSuccess) {
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 32768, cudaGetErrorString(result_elem_add_kernel));
        return -1;
    }

    return 0;
}

// Host entry point: builds one tiled TMA descriptor per tensor (A, B, C) and
// launches elem_add_kernel on `stream`.
// Each descriptor describes a rank-2 tensor of 1024 x 1024 elements with byte
// strides {4, 4096} (4096 = 1024 * 4) and a 64 x 64 box.
// Returns 0 on success; if a descriptor cannot be encoded, writes a message to
// error_buf and returns -1.
// NOTE(review): data type 7 and L2 promotion 2 are raw enum values, presumably
// FLOAT32 and L2_128B - confirm against cuda.h.
extern "C" int call(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {

        // Descriptor for input A.
        CUtensorMap A_desc;
        CUtensorMapDataType A_desc_type= (CUtensorMapDataType)7;
        cuuint32_t A_desc_tensorRank= 2;
        void *A_desc_globalAddress= A;
        cuuint64_t A_desc_globalDim[2]= {1024,1024};
        cuuint64_t A_desc_globalStride[2]= {4,4096};
        cuuint32_t A_desc_boxDim[2]= {64,64};
        cuuint32_t A_desc_elementStrides[2]= {1,1};
        CUtensorMapInterleave A_desc_interleave= (CUtensorMapInterleave)0;
        CUtensorMapSwizzle A_desc_swizzle= (CUtensorMapSwizzle)0;
        CUtensorMapL2promotion A_desc_l2Promotion= (CUtensorMapL2promotion)2;
        CUtensorMapFloatOOBfill A_desc_oobFill= (CUtensorMapFloatOOBfill)0;

        // `globalStride + 1` skips element 0, so only the 4096-byte row pitch is
        // passed (presumably because the API takes rank-1 strides - see the
        // cuTensorMapEncodeTiled documentation).
        CUresult A_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
    &A_desc, A_desc_type, A_desc_tensorRank, A_desc_globalAddress, A_desc_globalDim, A_desc_globalStride + 1, A_desc_boxDim, A_desc_elementStrides, A_desc_interleave, A_desc_swizzle, A_desc_l2Promotion, A_desc_oobFill);

        if (A_desc_result != CUDA_SUCCESS) {
                std::stringstream ss;
                ss << "Error: Failed to initialize the TMA descriptor A_desc";
                snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
                return -1;
        }

        // Descriptor for input B (same layout as A).
        CUtensorMap B_desc;
        CUtensorMapDataType B_desc_type= (CUtensorMapDataType)7;
        cuuint32_t B_desc_tensorRank= 2;
        void *B_desc_globalAddress= B;
        cuuint64_t B_desc_globalDim[2]= {1024,1024};
        cuuint64_t B_desc_globalStride[2]= {4,4096};
        cuuint32_t B_desc_boxDim[2]= {64,64};
        cuuint32_t B_desc_elementStrides[2]= {1,1};
        CUtensorMapInterleave B_desc_interleave= (CUtensorMapInterleave)0;
        CUtensorMapSwizzle B_desc_swizzle= (CUtensorMapSwizzle)0;
        CUtensorMapL2promotion B_desc_l2Promotion= (CUtensorMapL2promotion)2;
        CUtensorMapFloatOOBfill B_desc_oobFill= (CUtensorMapFloatOOBfill)0;

        CUresult B_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
    &B_desc, B_desc_type, B_desc_tensorRank, B_desc_globalAddress, B_desc_globalDim, B_desc_globalStride + 1, B_desc_boxDim, B_desc_elementStrides, B_desc_interleave, B_desc_swizzle, B_desc_l2Promotion, B_desc_oobFill);

        if (B_desc_result != CUDA_SUCCESS) {
                std::stringstream ss;
                ss << "Error: Failed to initialize the TMA descriptor B_desc";
                snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
                return -1;
        }

        // Descriptor for output C (same layout as A and B).
        CUtensorMap C_desc;
        CUtensorMapDataType C_desc_type= (CUtensorMapDataType)7;
        cuuint32_t C_desc_tensorRank= 2;
        void *C_desc_globalAddress= C;
        cuuint64_t C_desc_globalDim[2]= {1024,1024};
        cuuint64_t C_desc_globalStride[2]= {4,4096};
        cuuint32_t C_desc_boxDim[2]= {64,64};
        cuuint32_t C_desc_elementStrides[2]= {1,1};
        CUtensorMapInterleave C_desc_interleave= (CUtensorMapInterleave)0;
        CUtensorMapSwizzle C_desc_swizzle= (CUtensorMapSwizzle)0;
        CUtensorMapL2promotion C_desc_l2Promotion= (CUtensorMapL2promotion)2;
        CUtensorMapFloatOOBfill C_desc_oobFill= (CUtensorMapFloatOOBfill)0;

        CUresult C_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
    &C_desc, C_desc_type, C_desc_tensorRank, C_desc_globalAddress, C_desc_globalDim, C_desc_globalStride + 1, C_desc_boxDim, C_desc_elementStrides, C_desc_interleave, C_desc_swizzle, C_desc_l2Promotion, C_desc_oobFill);

        if (C_desc_result != CUDA_SUCCESS) {
                std::stringstream ss;
                ss << "Error: Failed to initialize the TMA descriptor C_desc";
                snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
                return -1;
        }
        // 16 x 16 blocks of 256 threads, 32768 bytes of dynamic shared memory.
        elem_add_kernel<<<dim3(16, 16, 1), dim3(256, 1, 1), 32768, stream>>>(A_desc, B_desc, C_desc);
        TILELANG_CHECK_LAST_ERROR("elem_add_kernel");

        return 0;
}

最后调用 nvcc 命令，将生成的 CUDA 源文件（.cu）编译为动态库（.so）：

['/usr/local/cuda-12.8/bin/nvcc',
'-std=c++17',
'-w',
'-Xcudafe',
'--diag_suppress=177',
'--compiler-options',
'-fPIC',
'-lineinfo',
'--shared',
'/tmp/tmp3f_1m6dk.cu',
'-lcuda',
'-gencode',
'arch=compute_120a,code=sm_120a',
'-I/mnt/d/0_work/share_with_ubuntu/tilelang/3rdparty/cutlass/include',
'-I/mnt/d/0_work/share_with_ubuntu/tilelang/3rdparty/../src',
'-o',
'/tmp/tmp3f_1m6dk.so']

Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐