TileLang：运行一个测试用例
本文以 TileLang 的一个逐元素加法（elementwise add）测试用例为例，跟踪 kernel 从编写、TIR 表示到 CUDA 后端编译的过程。
1. 使用 TileLang 语法编写一个 GPU kernel 函数
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
    """Build a TileLang kernel computing ``C = A + B`` for (M, N) tensors.

    The kernel launches a 2-D grid of
    ``ceildiv(N, block_N) x ceildiv(M, block_M)`` thread blocks; each block
    handles one ``(block_M, block_N)`` tile of the output.

    NOTE(review): assumes ``import tilelang`` and
    ``import tilelang.language as T`` are in scope (not shown in this note).

    Args:
        M, N: Global shape of ``A``, ``B`` and ``C``.
        block_M, block_N: Tile shape processed by one thread block.
        in_dtype: dtype of the inputs ``A`` and ``B``.
        out_dtype: dtype of the output ``C``.
        threads: Number of threads per block.

    Returns:
        The ``elem_add`` PrimFunc. ``out_idx=[-1]`` marks the last parameter
        (``C``) as the output, so the compiled kernel is invoked as
        ``out = kernel(a, b)``.
    """

    @T.prim_func
    def elem_add(
        A: T.Tensor((M, N), in_dtype),
        B: T.Tensor((M, N), in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        # bx indexes the tiles along N (columns), by indexes the tiles along M (rows).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            # Stage this block's input tiles from global into shared buffers.
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)

            # Element-wise add; T.Parallel lets TileLang distribute the
            # (block_M, block_N) iterations over the block's threads.
            for (local_y, local_x) in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]

            # Write the result back: fragment -> shared -> global tile of C.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return elem_add
调用代码（main 函数）：
def main():
    """Run the elementwise-add kernel on random CUDA tensors and verify it.

    Command-line options (unrecognized options are ignored):
        --m, --n         Tensor shape, default 1024 x 1024.
        --use_autotune   Obtain the kernel from ``get_best_config`` instead
                         of the fixed default tile configuration.

    Raises:
        AssertionError: if the kernel output differs from ``ref_program``
            beyond ``rtol=1e-2`` / ``atol=1e-2``.

    NOTE(review): ``elementwise_add``, ``get_best_config`` and
    ``ref_program`` are module-level names defined outside this function.
    """
    import argparse

    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=1024)
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--use_autotune", action="store_true", default=False)
    # parse_known_args so extra flags (e.g. from a test runner) don't abort the run.
    args, _ = parser.parse_known_args()
    M, N = args.m, args.n

    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    if args.use_autotune:
        result = get_best_config(M, N)
        kernel = result.kernel
    else:
        # Default config
        config = {"block_M": 64, "block_N": 64, "threads": 128}
        kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")

    # out_idx=[-1] on the kernel means the output tensor is allocated and returned.
    out = kernel(a, b)
    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
2. kernel 的 TIR 表示
<class 'tvm.tir.function.PrimFunc'>
# from tvm.script import tir as T
@T.prim_func
def elem_add(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
A = T.match_buffer(A_handle, (1024, 1024), strides=(1024, 1))
B = T.match_buffer(B_handle, (1024, 1024), strides=(1024, 1))
C = T.match_buffer(C_handle, (1024, 1024), strides=(1024, 1))
# with T.block("root"):
bx = T.launch_thread("blockIdx.x", 16)
by = T.launch_thread("blockIdx.y", 16)
tx = T.launch_thread("threadIdx.x", 128)
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
with T.block("tilelang_root"):
T.reads(A[by * 64, bx * 64], B[by * 64, bx * 64], C[by * 64, bx * 64])
T.writes()
A_shared = T.alloc_buffer((64, 64), scope="shared.dyn")
B_shared = T.alloc_buffer((64, 64), scope="shared.dyn")
C_local = T.alloc_buffer((64, 64), scope="local.fragment")
C_shared = T.alloc_buffer((64, 64), scope="shared.dyn")
T.copy(T.region(A[by * 64, bx * 64], 1, 64, 64), T.region(A_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0)
T.copy(T.region(B[by * 64, bx * 64], 1, 64, 64), T.region(B_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0)
for local_y in T.parallel(64):
for local_x in T.parallel(64):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C_shared[0, 0], 2, 64, 64), -1, T.bool(False), 0)
T.copy(T.region(C_shared[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64), -1, T.bool(False), 0)
3. 对 TIR PrimFunc 进行 CUDA 后端编译
3.1 tilelang.jit.compile
tilelang/jit/__init__.py
def compile(func: PrimFunc = None,
3.2 tilelang.cache.cached（模块级函数）
tilelang/cache/__init__.py
def cached(func: PrimFunc = None
3.3 KernelCache.cached（KernelCache 的方法）
tilelang/cache/kernel_cache.py
class KernelCache:
def cached(self,func: PrimFunc = None
3.4 KernelCache._load_kernel_from_disk
tilelang/cache/kernel_cache.py
class KernelCache:
def _load_kernel_from_disk(
检查磁盘上是否已存在对应的缓存文件。
3.5 初始化JITKernel
kernel = JITKernel(
func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
3.6 JITKernel.__init__ 函数
tilelang/jit/kernel.py
class JITKernel:
def __init__(
其中调用 `adapter = self._compile_and_create_adapter(func, out_idx)`。
3.7 JITKernel._compile_and_create_adapter
其中调用 `tilelang.lower`：
artifact = tilelang.lower(
tilelang_func,
target=target,
target_host=target_host,
enable_host_codegen=enable_host_codegen,
enable_device_compile=enable_device_compile)
3.8 tilelang.lower
tilelang/engine/lower.py
def lower(func_or_mod: tir.PrimFunc | tvm.IRModule,
mod = LowerAndLegalize(mod, target)
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def elem_add(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
T.func_attr({"target": T.target({"arch": "sm_120", "host": {"keys": ["cpu"], "kind": "c", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32})})
A = T.match_buffer(A_handle, (1024, 1024), strides=(1024, 1))
B = T.match_buffer(B_handle, (1024, 1024), strides=(1024, 1))
C = T.match_buffer(C_handle, (1024, 1024), strides=(1024, 1))
with T.block("root"):
T.reads()
T.writes()
A_shared = T.Buffer((1, 1, 64, 64), scope="shared.dyn")
B_shared = T.Buffer((1, 1, 64, 64), scope="shared.dyn")
C_local = T.Buffer((32,), scope="local")
C_shared = T.Buffer((1, 1, 64, 64), scope="shared.dyn")
T.block_attr({"layout_map": {A_shared: metadata["tl.Layout"][0], B_shared: metadata["tl.Layout"][1], C_local: metadata["tl.Fragment"][0], C_shared: metadata["tl.Layout"][2]}})
bx = T.launch_thread("blockIdx.x", 16)
by = T.launch_thread("blockIdx.y", 16)
tx = T.launch_thread("threadIdx.x", 128)
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
with T.block("tilelang_root"):
T.reads(A[by * 64, bx * 64], B[by * 64, bx * 64], C[by * 64, bx * 64])
T.writes()
T.block_attr({"layout_map": {A_shared: metadata["tl.Layout"][0], B_shared: metadata["tl.Layout"][1], C_local: metadata["tl.Fragment"][0], C_shared: metadata["tl.Layout"][2]}})
A_shared = T.alloc_buffer((1, 1, 64, 64), data=A_shared.data, scope="shared.dyn")
B_shared = T.alloc_buffer((1, 1, 64, 64), data=B_shared.data, scope="shared.dyn")
C_local = T.alloc_buffer((32,), data=C_local.data, scope="local")
C_shared = T.alloc_buffer((1, 1, 64, 64), data=C_shared.data, scope="shared.dyn")
if tx == 0:
T.tma_load(T.create_tma_descriptor(7, 2, A.data, 1024, 1024, T.int64(4), T.int64(4096), 64, 64, 1, 1, 0, 0, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float32"), A_shared.data, 0, 4096, 2), bx * 64, by * 64, 0)
if tx == 0:
T.tma_load(T.create_tma_descriptor(7, 2, B.data, 1024, 1024, T.int64(4), T.int64(4096), 64, 64, 1, 1, 0, 0, 2, 0), 0, T.tvm_access_ptr(T.type_annotation("float32"), B_shared.data, 0, 4096, 2), bx * 64, by * 64, 0)
for i in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(4):
C_local[i * 4 + vec] = A_shared[0, 0, i * 8 + tx // 16, tx % 16 * 4 + vec] + B_shared[0, 0, i * 8 + tx // 16, tx % 16 * 4 + vec]
for i in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(4):
C_shared[0, 0, i * 8 + tx // 16, tx % 16 * 4 + vec] = C_local[i * 4 + vec]
if tx == 0:
T.tma_store(T.create_tma_descriptor(7, 2, C.data, 1024, 1024, T.int64(4), T.int64(4096), 64, 64, 1, 1, 0, 0, 2, 0), T.tvm_access_ptr(T.type_annotation("float32"), C_shared.data, 0, 4096, 1), bx * 64, by * 64, 0, 0)
# Metadata omitted. Use show_meta=True in script() method to show it.
mod = OptimizeForTarget(mod, target)
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def elem_add(args: T.handle, arg_type_ids: T.handle("int32", "global"), num_args: T.int32, out_ret_value: T.handle("void", "global"), out_ret_tcode: T.handle("int32", "global"), resource_handle: T.handle) -> T.int32:
C_desc = T.handle("uint8x128", "grid_constant")
C = T.handle("float32", "global")
B_desc = T.handle("uint8x128", "grid_constant")
B = T.handle("float32", "global")
A_desc = T.handle("uint8x128", "grid_constant")
A = T.handle("float32", "global")
T.func_attr({"calling_conv": 1, "target": T.target({"keys": ["cpu"], "kind": "c", "tag": ""}), "thread_extent": {}, "tir.is_entry_func": True, "tma_descriptor_args": {C_desc: ["__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], B_desc: ["__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], A_desc: ["__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0]}})
assert num_args == 3, "elem_add: num_args should be 3"
assert not T.isnullptr(args), "elem_add: TVMValue* arg pointer was NULL"
assert not T.isnullptr(arg_type_ids), "elem_add: int* type_codes was NULL"
arg_type_ids_1 = T.decl_buffer((3,), "int32", data=arg_type_ids)
A_handle_code: T.int32 = arg_type_ids_1[0]
assert A_handle_code == 0 or A_handle_code == 4 or A_handle_code == 7 or A_handle_code >= 64, "elem_add: Expect arg[0] to be pointer"
B_handle_code: T.int32 = arg_type_ids_1[1]
assert B_handle_code == 0 or B_handle_code == 4 or B_handle_code == 7 or B_handle_code >= 64, "elem_add: Expect arg[1] to be pointer"
C_handle_code: T.int32 = arg_type_ids_1[2]
assert C_handle_code == 0 or C_handle_code == 4 or C_handle_code == 7 or C_handle_code >= 64, "elem_add: Expect arg[2] to be pointer"
A_handle: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
B_handle: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
C_handle: T.handle = T.tvm_struct_get(args, 2, 12, "handle")
assert not T.isnullptr(A_handle), "elem_add.A_handle is expected to have non-NULL DLTensor* pointer"
assert 2 == T.tvm_struct_get(A_handle, 0, 4, "int32"), "elem_add.A_handle.ndim is expected to equal 2"
elem_add_A_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 2, "handle")
elem_add_A_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_shape)
elem_add_A_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 3, "handle")
elem_add_A_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_strides)
dev_id: T.int32 = T.tvm_struct_get(A_handle, 0, 9, "int32")
with T.LetStmt(T.tvm_struct_get(A_handle, 0, 1, "handle"), var=A):
T.attr(A, "storage_alignment", 64)
assert not T.isnullptr(B_handle), "elem_add.B_handle is expected to have non-NULL DLTensor* pointer"
assert 2 == T.tvm_struct_get(B_handle, 0, 4, "int32"), "elem_add.B_handle.ndim is expected to equal 2"
elem_add_B_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 2, "handle")
elem_add_B_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_shape)
elem_add_B_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 3, "handle")
elem_add_B_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_strides)
with T.LetStmt(T.tvm_struct_get(B_handle, 0, 1, "handle"), var=B):
T.attr(B, "storage_alignment", 64)
assert not T.isnullptr(C_handle), "elem_add.C_handle is expected to have non-NULL DLTensor* pointer"
assert 2 == T.tvm_struct_get(C_handle, 0, 4, "int32"), "elem_add.C_handle.ndim is expected to equal 2"
elem_add_C_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 2, "handle")
elem_add_C_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_shape)
elem_add_C_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 3, "handle")
elem_add_C_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_strides)
with T.LetStmt(T.tvm_struct_get(C_handle, 0, 1, "handle"), var=C):
T.attr(C, "storage_alignment", 64)
T.attr("default", "device_id", dev_id)
T.attr("default", "device_type", 2)
assert T.tvm_struct_get(A_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(A_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(A_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.A_handle.dtype is expected to be float32"
assert T.Cast("int32", elem_add_A_handle_shape_1[0]) == 1024, "Argument elem_add.A_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[0])"
assert T.Cast("int32", elem_add_A_handle_shape_1[1]) == 1024, "Argument elem_add.A_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[1])"
assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast("int32", elem_add_A_handle_strides_1[1])) == 1, "Argument elem_add.A_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast(\"int32\", elem_add_A_handle_strides_1[1]))"
assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast("int32", elem_add_A_handle_shape_1[1]), T.Cast("int32", elem_add_A_handle_strides_1[0])) == 1024, "Argument elem_add.A_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast(\"int32\", elem_add_A_handle_shape[1]), T.Cast(\"int32\", elem_add_A_handle_strides_1[0]))"
assert T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, "uint64"), "Argument elem_add.A_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, \"uint64\")"
assert T.tvm_struct_get(A_handle, 0, 10, "int32") == 2, "Argument elem_add.A_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(A_handle, 0, 10, \"int32\")"
assert not T.isnullptr(A), "elem_add.A_handle is expected to have non-NULL data pointer"
assert T.tvm_struct_get(B_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(B_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(B_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.B_handle.dtype is expected to be float32"
assert T.Cast("int32", elem_add_B_handle_shape_1[0]) == 1024, "Argument elem_add.B_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[0])"
assert T.Cast("int32", elem_add_B_handle_shape_1[1]) == 1024, "Argument elem_add.B_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[1])"
assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast("int32", elem_add_B_handle_strides_1[1])) == 1, "Argument elem_add.B_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast(\"int32\", elem_add_B_handle_strides_1[1]))"
assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast("int32", elem_add_B_handle_shape_1[1]), T.Cast("int32", elem_add_B_handle_strides_1[0])) == 1024, "Argument elem_add.B_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast(\"int32\", elem_add_B_handle_shape[1]), T.Cast(\"int32\", elem_add_B_handle_strides_1[0]))"
assert T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, "uint64"), "Argument elem_add.B_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, \"uint64\")"
assert T.tvm_struct_get(B_handle, 0, 10, "int32") == 2, "Argument elem_add.B_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(B_handle, 0, 10, \"int32\")"
assert dev_id == T.tvm_struct_get(B_handle, 0, 9, "int32"), "Argument elem_add.B_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(B_handle, 0, 9, \"int32\")"
assert not T.isnullptr(B), "elem_add.B_handle is expected to have non-NULL data pointer"
assert T.tvm_struct_get(C_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(C_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(C_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.C_handle.dtype is expected to be float32"
assert T.Cast("int32", elem_add_C_handle_shape_1[0]) == 1024, "Argument elem_add.C_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[0])"
assert T.Cast("int32", elem_add_C_handle_shape_1[1]) == 1024, "Argument elem_add.C_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[1])"
assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast("int32", elem_add_C_handle_strides_1[1])) == 1, "Argument elem_add.C_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast(\"int32\", elem_add_C_handle_strides_1[1]))"
assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast("int32", elem_add_C_handle_shape_1[1]), T.Cast("int32", elem_add_C_handle_strides_1[0])) == 1024, "Argument elem_add.C_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast(\"int32\", elem_add_C_handle_shape[1]), T.Cast(\"int32\", elem_add_C_handle_strides_1[0]))"
assert T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, "uint64"), "Argument elem_add.C_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, \"uint64\")"
assert T.tvm_struct_get(C_handle, 0, 10, "int32") == 2, "Argument elem_add.C_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(C_handle, 0, 10, \"int32\")"
assert dev_id == T.tvm_struct_get(C_handle, 0, 9, "int32"), "Argument elem_add.C_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(C_handle, 0, 9, \"int32\")"
assert not T.isnullptr(C), "elem_add.C_handle is expected to have non-NULL data pointer"
A_1 = T.decl_buffer((1024, 1024), data=A, strides=(1024, 1))
B_1 = T.decl_buffer((1024, 1024), data=B, strides=(1024, 1))
C_1 = T.decl_buffer((1024, 1024), data=C, strides=(1024, 1))
assert T.FloorMod(1024, 8) == 0, "A: Vectorize dimension in buffer must be divisible by 8"
assert T.FloorMod(1024, 8) == 0, "B: Vectorize dimension in buffer must be divisible by 8"
assert T.FloorMod(1024, 8) == 0, "C: Vectorize dimension in buffer must be divisible by 8"
T.call_packed("__tvm_set_device", 2, dev_id)
with T.attr(0, "compute_scope", "elem_add_compute_"):
with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=A_desc):
T.call_packed("__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=B_desc):
T.call_packed("__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=C_desc):
T.call_packed("__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
T.call_packed("elem_add_kernel", A_desc, B_desc, C_desc, 16, 16, 256, 1, 1, 32768)
return 0
@T.prim_func
def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
buf_dyn_shmem = T.handle("uint8", "shared.dyn")
C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
C_local = T.handle("float32", "local")
C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
bx = T.launch_thread("blockIdx.x", 16)
buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
C_local = T.allocate([32], "float32", "local")
by = T.launch_thread("blockIdx.y", 16)
tx = T.launch_thread("threadIdx.x", 256)
T.create_barriers(1)
if T.tl_shuffle_elect(0):
T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
T.ptx_fence_barrier_init()
T.tvm_storage_sync("shared")
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if 128 <= tx:
T.set_max_nreg(24, 0)
if T.tl_shuffle_elect(128):
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
T.tma_load(A_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 2), bx * 64, by * 64, 0)
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
T.tma_load(B_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 4096, 4096, 2), bx * 64, by * 64, 0)
T.ptx_arrive_barrier(T.get_mbarrier(0))
else:
T.set_max_nreg(240, 1)
T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
for i in T.unroll(8):
C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
T.tvm_storage_sync("shared.dyn", 3, 128)
for i in T.unroll(8):
C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
T.fence_proxy_async()
T.tvm_storage_sync("shared.dyn", 3, 128)
if T.tl_shuffle_elect(128):
T.tma_store(C_desc, T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 1), bx * 64, by * 64, 0, 0)
T.tma_store_arrive()
T.tma_store_wait()
host_mod = tir.transform.Filter(_is_host_call)(mod)
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def elem_add(args: T.handle, arg_type_ids: T.handle("int32", "global"), num_args: T.int32, out_ret_value: T.handle("void", "global"), out_ret_tcode: T.handle("int32", "global"), resource_handle: T.handle) -> T.int32:
C_desc = T.handle("uint8x128", "grid_constant")
C = T.handle("float32", "global")
B_desc = T.handle("uint8x128", "grid_constant")
B = T.handle("float32", "global")
A_desc = T.handle("uint8x128", "grid_constant")
A = T.handle("float32", "global")
T.func_attr({"calling_conv": 1, "target": T.target({"keys": ["cpu"], "kind": "c", "tag": ""}), "thread_extent": {}, "tir.is_entry_func": True, "tma_descriptor_args": {C_desc: ["__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], B_desc: ["__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0], A_desc: ["__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0]}})
assert num_args == 3, "elem_add: num_args should be 3"
assert not T.isnullptr(args), "elem_add: TVMValue* arg pointer was NULL"
assert not T.isnullptr(arg_type_ids), "elem_add: int* type_codes was NULL"
arg_type_ids_1 = T.decl_buffer((3,), "int32", data=arg_type_ids)
A_handle_code: T.int32 = arg_type_ids_1[0]
assert A_handle_code == 0 or A_handle_code == 4 or A_handle_code == 7 or A_handle_code >= 64, "elem_add: Expect arg[0] to be pointer"
B_handle_code: T.int32 = arg_type_ids_1[1]
assert B_handle_code == 0 or B_handle_code == 4 or B_handle_code == 7 or B_handle_code >= 64, "elem_add: Expect arg[1] to be pointer"
C_handle_code: T.int32 = arg_type_ids_1[2]
assert C_handle_code == 0 or C_handle_code == 4 or C_handle_code == 7 or C_handle_code >= 64, "elem_add: Expect arg[2] to be pointer"
A_handle: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
B_handle: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
C_handle: T.handle = T.tvm_struct_get(args, 2, 12, "handle")
assert not T.isnullptr(A_handle), "elem_add.A_handle is expected to have non-NULL DLTensor* pointer"
assert 2 == T.tvm_struct_get(A_handle, 0, 4, "int32"), "elem_add.A_handle.ndim is expected to equal 2"
elem_add_A_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 2, "handle")
elem_add_A_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_shape)
elem_add_A_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(A_handle, 0, 3, "handle")
elem_add_A_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_A_handle_strides)
dev_id: T.int32 = T.tvm_struct_get(A_handle, 0, 9, "int32")
with T.LetStmt(T.tvm_struct_get(A_handle, 0, 1, "handle"), var=A):
T.attr(A, "storage_alignment", 64)
assert not T.isnullptr(B_handle), "elem_add.B_handle is expected to have non-NULL DLTensor* pointer"
assert 2 == T.tvm_struct_get(B_handle, 0, 4, "int32"), "elem_add.B_handle.ndim is expected to equal 2"
elem_add_B_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 2, "handle")
elem_add_B_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_shape)
elem_add_B_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(B_handle, 0, 3, "handle")
elem_add_B_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_B_handle_strides)
with T.LetStmt(T.tvm_struct_get(B_handle, 0, 1, "handle"), var=B):
T.attr(B, "storage_alignment", 64)
assert not T.isnullptr(C_handle), "elem_add.C_handle is expected to have non-NULL DLTensor* pointer"
assert 2 == T.tvm_struct_get(C_handle, 0, 4, "int32"), "elem_add.C_handle.ndim is expected to equal 2"
elem_add_C_handle_shape: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 2, "handle")
elem_add_C_handle_shape_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_shape)
elem_add_C_handle_strides: T.handle("int64", "global") = T.tvm_struct_get(C_handle, 0, 3, "handle")
elem_add_C_handle_strides_1 = T.decl_buffer((2,), "int64", data=elem_add_C_handle_strides)
with T.LetStmt(T.tvm_struct_get(C_handle, 0, 1, "handle"), var=C):
T.attr(C, "storage_alignment", 64)
T.attr("default", "device_id", dev_id)
T.attr("default", "device_type", 2)
assert T.tvm_struct_get(A_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(A_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(A_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.A_handle.dtype is expected to be float32"
assert T.Cast("int32", elem_add_A_handle_shape_1[0]) == 1024, "Argument elem_add.A_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[0])"
assert T.Cast("int32", elem_add_A_handle_shape_1[1]) == 1024, "Argument elem_add.A_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_A_handle_shape[1])"
assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast("int32", elem_add_A_handle_strides_1[1])) == 1, "Argument elem_add.A_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), 1, T.Cast(\"int32\", elem_add_A_handle_strides_1[1]))"
assert T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast("int32", elem_add_A_handle_shape_1[1]), T.Cast("int32", elem_add_A_handle_strides_1[0])) == 1024, "Argument elem_add.A_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_A_handle_strides), T.Cast(\"int32\", elem_add_A_handle_shape[1]), T.Cast(\"int32\", elem_add_A_handle_strides_1[0]))"
assert T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, "uint64"), "Argument elem_add.A_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(A_handle, 0, 8, \"uint64\")"
assert T.tvm_struct_get(A_handle, 0, 10, "int32") == 2, "Argument elem_add.A_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(A_handle, 0, 10, \"int32\")"
assert not T.isnullptr(A), "elem_add.A_handle is expected to have non-NULL data pointer"
assert T.tvm_struct_get(B_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(B_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(B_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.B_handle.dtype is expected to be float32"
assert T.Cast("int32", elem_add_B_handle_shape_1[0]) == 1024, "Argument elem_add.B_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[0])"
assert T.Cast("int32", elem_add_B_handle_shape_1[1]) == 1024, "Argument elem_add.B_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_B_handle_shape[1])"
assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast("int32", elem_add_B_handle_strides_1[1])) == 1, "Argument elem_add.B_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), 1, T.Cast(\"int32\", elem_add_B_handle_strides_1[1]))"
assert T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast("int32", elem_add_B_handle_shape_1[1]), T.Cast("int32", elem_add_B_handle_strides_1[0])) == 1024, "Argument elem_add.B_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_B_handle_strides), T.Cast(\"int32\", elem_add_B_handle_shape[1]), T.Cast(\"int32\", elem_add_B_handle_strides_1[0]))"
assert T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, "uint64"), "Argument elem_add.B_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(B_handle, 0, 8, \"uint64\")"
assert T.tvm_struct_get(B_handle, 0, 10, "int32") == 2, "Argument elem_add.B_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(B_handle, 0, 10, \"int32\")"
assert dev_id == T.tvm_struct_get(B_handle, 0, 9, "int32"), "Argument elem_add.B_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(B_handle, 0, 9, \"int32\")"
assert not T.isnullptr(B), "elem_add.B_handle is expected to have non-NULL data pointer"
assert T.tvm_struct_get(C_handle, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(C_handle, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(C_handle, 0, 7, "uint16") == T.uint16(1), "elem_add.C_handle.dtype is expected to be float32"
assert T.Cast("int32", elem_add_C_handle_shape_1[0]) == 1024, "Argument elem_add.C_handle.shape[0] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[0])"
assert T.Cast("int32", elem_add_C_handle_shape_1[1]) == 1024, "Argument elem_add.C_handle.shape[1] has an unsatisfied constraint: 1024 == T.Cast(\"int32\", elem_add_C_handle_shape[1])"
assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast("int32", elem_add_C_handle_strides_1[1])) == 1, "Argument elem_add.C_handle.strides[1] has an unsatisfied constraint: 1 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), 1, T.Cast(\"int32\", elem_add_C_handle_strides_1[1]))"
assert T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast("int32", elem_add_C_handle_shape_1[1]), T.Cast("int32", elem_add_C_handle_strides_1[0])) == 1024, "Argument elem_add.C_handle.strides[0] has an unsatisfied constraint: 1024 == T.if_then_else(T.isnullptr(elem_add_C_handle_strides), T.Cast(\"int32\", elem_add_C_handle_shape[1]), T.Cast(\"int32\", elem_add_C_handle_strides_1[0]))"
assert T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, "uint64"), "Argument elem_add.C_handle.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(C_handle, 0, 8, \"uint64\")"
assert T.tvm_struct_get(C_handle, 0, 10, "int32") == 2, "Argument elem_add.C_handle.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(C_handle, 0, 10, \"int32\")"
assert dev_id == T.tvm_struct_get(C_handle, 0, 9, "int32"), "Argument elem_add.C_handle.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(C_handle, 0, 9, \"int32\")"
assert not T.isnullptr(C), "elem_add.C_handle is expected to have non-NULL data pointer"
A_1 = T.decl_buffer((1024, 1024), data=A, strides=(1024, 1))
B_1 = T.decl_buffer((1024, 1024), data=B, strides=(1024, 1))
C_1 = T.decl_buffer((1024, 1024), data=C, strides=(1024, 1))
assert T.FloorMod(1024, 8) == 0, "A: Vectorize dimension in buffer must be divisible by 8"
assert T.FloorMod(1024, 8) == 0, "B: Vectorize dimension in buffer must be divisible by 8"
assert T.FloorMod(1024, 8) == 0, "C: Vectorize dimension in buffer must be divisible by 8"
T.call_packed("__tvm_set_device", 2, dev_id)
with T.attr(0, "compute_scope", "elem_add_compute_"):
with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=A_desc):
T.call_packed("__tvm_tensormap_create_tiled", A_desc, 7, 2, A, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=B_desc):
T.call_packed("__tvm_tensormap_create_tiled", B_desc, 7, 2, B, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
with T.LetStmt(T.tvm_stack_alloca("arg_value", 16), var=C_desc):
T.call_packed("__tvm_tensormap_create_tiled", C_desc, 7, 2, C, 1024, 1024, 4, 4096, 64, 64, 1, 1, 0, 0, 2, 0)
T.call_packed("elem_add_kernel", A_desc, B_desc, C_desc, 16, 16, 256, 1, 1, 32768)
return 0
device_mod = tir.transform.Filter(_is_device_call)(mod)
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
buf_dyn_shmem = T.handle("uint8", "shared.dyn")
C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
C_local = T.handle("float32", "local")
C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
bx = T.launch_thread("blockIdx.x", 16)
buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
C_local = T.allocate([32], "float32", "local")
by = T.launch_thread("blockIdx.y", 16)
tx = T.launch_thread("threadIdx.x", 256)
T.create_barriers(1)
if T.tl_shuffle_elect(0):
T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
T.ptx_fence_barrier_init()
T.tvm_storage_sync("shared")
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if 128 <= tx:
T.set_max_nreg(24, 0)
if T.tl_shuffle_elect(128):
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
T.tma_load(A_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 2), bx * 64, by * 64, 0)
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
T.tma_load(B_desc, T.get_mbarrier(0), T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 4096, 4096, 2), bx * 64, by * 64, 0)
T.ptx_arrive_barrier(T.get_mbarrier(0))
else:
T.set_max_nreg(240, 1)
T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
for i in T.unroll(8):
C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
T.tvm_storage_sync("shared.dyn", 3, 128)
for i in T.unroll(8):
C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
T.fence_proxy_async()
T.tvm_storage_sync("shared.dyn", 3, 128)
if T.tl_shuffle_elect(128):
T.tma_store(C_desc, T.tvm_access_ptr(T.type_annotation("float32"), buf_dyn_shmem, 0, 4096, 1), bx * 64, by * 64, 0, 0)
T.tma_store_arrive()
T.tma_store_wait()
调用 device_codegen_without_compile（当 enable_device_compile 为 False 时走这个分支，而不是 device_codegen）：
codegen_mod = device_codegen(
device_mod, target) if enable_device_compile else device_codegen_without_compile(
device_mod, target)
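device_codegen_without_compile 定义在 tilelang/engine/lower.py 中：
def device_codegen_without_compile(device_mod, target):
在本例（CUDA 目标）中，该函数依次对 device_mod 执行 LowerDeviceStorageAccessInfo、LowerIntrin、Simplify 三个 pass，最后调用 target.build.tilelang_cuda_without_compile 生成 CUDA 源码。下面依次给出每一步之后的 IR。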
经过 LowerDeviceStorageAccessInfo 之后的 IR（T.tvm_access_ptr(...) 被改写为 T.address_of(buf_dyn_shmem_x[offset]) 的形式）：
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
buf_dyn_shmem = T.handle("uint8", "shared.dyn")
C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
C_local = T.handle("float32", "local")
C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
bx = T.launch_thread("blockIdx.x", 16)
buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
C_local = T.allocate([32], "float32", "local")
by = T.launch_thread("blockIdx.y", 16)
tx = T.launch_thread("threadIdx.x", 256)
T.create_barriers(1)
if T.tl_shuffle_elect(0):
T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
T.ptx_fence_barrier_init()
T.tvm_storage_sync("shared")
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if 128 <= tx:
T.set_max_nreg(24, 0)
if T.tl_shuffle_elect(128):
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_load(A_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0)
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
buf_dyn_shmem_2 = T.Buffer((4097,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_load(B_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_2[4096]), bx * 64, by * 64, 0)
T.ptx_arrive_barrier(T.get_mbarrier(0))
else:
T.set_max_nreg(240, 1)
T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
for i in T.unroll(8):
C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
T.tvm_storage_sync("shared.dyn", 3, 128)
for i in T.unroll(8):
C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
T.fence_proxy_async()
T.tvm_storage_sync("shared.dyn", 3, 128)
if T.tl_shuffle_elect(128):
buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_store(C_desc, T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0, 0)
T.tma_store_arrive()
T.tma_store_wait()
经过 LowerIntrin 之后的 IR（本例中与上一步完全相同，没有变化）：
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
buf_dyn_shmem = T.handle("uint8", "shared.dyn")
C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
C_local = T.handle("float32", "local")
C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
bx = T.launch_thread("blockIdx.x", 16)
buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
C_local = T.allocate([32], "float32", "local")
by = T.launch_thread("blockIdx.y", 16)
tx = T.launch_thread("threadIdx.x", 256)
T.create_barriers(1)
if T.tl_shuffle_elect(0):
T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
T.ptx_fence_barrier_init()
T.tvm_storage_sync("shared")
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if 128 <= tx:
T.set_max_nreg(24, 0)
if T.tl_shuffle_elect(128):
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_load(A_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0)
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
buf_dyn_shmem_2 = T.Buffer((4097,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_load(B_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_2[4096]), bx * 64, by * 64, 0)
T.ptx_arrive_barrier(T.get_mbarrier(0))
else:
T.set_max_nreg(240, 1)
T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
for i in T.unroll(8):
C_local_1[i * 4:i * 4 + 4] = A_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] + B_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(4096, 4)]
T.tvm_storage_sync("shared.dyn", 3, 128)
for i in T.unroll(8):
C_shared[T.Ramp(i * 512 + tx * 4, 1, 4) + T.Broadcast(0, 4)] = C_local_1[i * 4:i * 4 + 4]
T.fence_proxy_async()
T.tvm_storage_sync("shared.dyn", 3, 128)
if T.tl_shuffle_elect(128):
buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_store(C_desc, T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0, 0)
T.tma_store_arrive()
T.tma_store_wait()
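经过 Simplify 之后的 IR（T.Ramp(...) + T.Broadcast(...) 形式的向量化下标被化简为切片形式，例如 B_shared[i * 512 + tx * 4 + 4096:i * 512 + tx * 4 + 4096 + 4]）：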
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def elem_add_kernel(A_desc: T.handle("uint8x128", "grid_constant"), B_desc: T.handle("uint8x128", "grid_constant"), C_desc: T.handle("uint8x128", "grid_constant")):
T.func_attr({"calling_conv": 2, "dyn_shared_memory_buf": 32768, "target": T.target({"arch": "sm_120", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "thread_extent": {"blockIdx.x": 16, "blockIdx.y": 16, "threadIdx.x": 256, "threadIdx.y": 1, "threadIdx.z": 1}, "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y", "threadIdx.z", "tir.use_dyn_shared_memory"], "tir.noalias": True})
buf_dyn_shmem = T.handle("uint8", "shared.dyn")
C_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
B_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
A_shared = T.decl_buffer((4096,), data=buf_dyn_shmem, scope="shared.dyn")
C_local = T.handle("float32", "local")
C_local_1 = T.decl_buffer((32,), data=C_local, scope="local")
bx = T.launch_thread("blockIdx.x", 16)
buf_dyn_shmem = T.allocate([32768], "uint8", "shared.dyn")
C_local = T.allocate([32], "float32", "local")
by = T.launch_thread("blockIdx.y", 16)
tx = T.launch_thread("threadIdx.x", 256)
T.create_barriers(1)
if T.tl_shuffle_elect(0):
T.call_extern("handle", "tl::prefetch_tma_descriptor", A_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", B_desc)
T.call_extern("handle", "tl::prefetch_tma_descriptor", C_desc)
T.ptx_init_barrier_thread_count(T.get_mbarrier(0), 128)
T.ptx_fence_barrier_init()
T.tvm_storage_sync("shared")
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if 128 <= tx:
T.set_max_nreg(24, 0)
if T.tl_shuffle_elect(128):
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_load(A_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0)
T.mbarrier_expect_tx(T.get_mbarrier(0), 16384)
T.fence_proxy_async()
buf_dyn_shmem_2 = T.Buffer((4097,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_load(B_desc, T.get_mbarrier(0), T.address_of(buf_dyn_shmem_2[4096]), bx * 64, by * 64, 0)
T.ptx_arrive_barrier(T.get_mbarrier(0))
else:
T.set_max_nreg(240, 1)
T.mbarrier_wait_parity(T.get_mbarrier(0), 0)
for i in T.unroll(8):
C_local_1[i * 4:i * 4 + 4] = A_shared[i * 512 + tx * 4:i * 512 + tx * 4 + 4] + B_shared[i * 512 + tx * 4 + 4096:i * 512 + tx * 4 + 4096 + 4]
T.tvm_storage_sync("shared.dyn", 3, 128)
for i in T.unroll(8):
C_shared[i * 512 + tx * 4:i * 512 + tx * 4 + 4] = C_local_1[i * 4:i * 4 + 4]
T.fence_proxy_async()
T.tvm_storage_sync("shared.dyn", 3, 128)
if T.tl_shuffle_elect(128):
buf_dyn_shmem_1 = T.Buffer((1,), data=buf_dyn_shmem, scope="shared.dyn")
T.tma_store(C_desc, T.address_of(buf_dyn_shmem_1[0]), bx * 64, by * 64, 0, 0)
T.tma_store_arrive()
T.tma_store_wait()
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif
extern "C" __global__ void elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc);
extern "C" __global__ void __launch_bounds__(256, 1) elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float C_local[32];
__shared__ uint64_t mbarrier_mem[1];
auto mbarrier = reinterpret_cast<Barrier*>(mbarrier_mem);
if (tl::tl_shuffle_elect<0>()) {
tl::prefetch_tma_descriptor(A_desc);
tl::prefetch_tma_descriptor(B_desc);
tl::prefetch_tma_descriptor(C_desc);
mbarrier[0].init(128);
}
tl::fence_barrier_init();
__syncthreads();
if (128 <= ((int)threadIdx.x)) {
tl::warpgroup_reg_dealloc<24>();
if (tl::tl_shuffle_elect<128>()) {
mbarrier[0].expect_transaction(16384);
tl::fence_proxy_async();
tl::tma_load(A_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
mbarrier[0].expect_transaction(16384);
tl::fence_proxy_async();
tl::tma_load(B_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[4096])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
}
mbarrier[0].arrive();
} else {
tl::warpgroup_reg_alloc<240>();
mbarrier[0].wait(0);
#pragma unroll
for (int i = 0; i < 8; ++i) {
float4 __1;
float4 v_ = *(float4*)(((float*)buf_dyn_shmem) + ((i * 512) + (((int)threadIdx.x) * 4)));
float4 v__1 = *(float4*)(((float*)buf_dyn_shmem) + (((i * 512) + (((int)threadIdx.x) * 4)) + 4096));
__1.x = (v_.x+v__1.x);
__1.y = (v_.y+v__1.y);
__1.z = (v_.z+v__1.z);
__1.w = (v_.w+v__1.w);
*(float4*)(C_local + (i * 4)) = __1;
}
tl::__sync_thread_partial<3, 128>();
#pragma unroll
for (int i_1 = 0; i_1 < 8; ++i_1) {
*(float4*)(((float*)buf_dyn_shmem) + ((i_1 * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(C_local + (i_1 * 4));
}
tl::fence_proxy_async();
tl::__sync_thread_partial<3, 128>();
if (tl::tl_shuffle_elect<128>()) {
tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
tl::tma_store_arrive();
tl::tma_store_wait<0>();
}
}
}
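CythonKernelAdapter
调用完 tilelang.lower 之后，会用 lower 得到的产物（artifact 中的 params、host_mod、device_mod、kernel_source 等）构造 CythonKernelAdapter。adapter 内部再通过 self.wrapper.wrap 把 kernel 源码包装成一个完整的 .cu 文件，其中额外包含主机端入口 init()（设置动态共享内存大小）和 call()（创建 TMA 描述符并启动 kernel）：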
adapter = CythonKernelAdapter(
params=artifact.params,
result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
host_mod=artifact.host_mod,
device_mod=artifact.device_mod,
kernel_global_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif
extern "C" __global__ void elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc);
extern "C" __global__ void __launch_bounds__(256, 1) elem_add_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float C_local[32];
__shared__ uint64_t mbarrier_mem[1];
auto mbarrier = reinterpret_cast<Barrier*>(mbarrier_mem);
if (tl::tl_shuffle_elect<0>()) {
tl::prefetch_tma_descriptor(A_desc);
tl::prefetch_tma_descriptor(B_desc);
tl::prefetch_tma_descriptor(C_desc);
mbarrier[0].init(128);
}
tl::fence_barrier_init();
__syncthreads();
if (128 <= ((int)threadIdx.x)) {
tl::warpgroup_reg_dealloc<24>();
if (tl::tl_shuffle_elect<128>()) {
mbarrier[0].expect_transaction(16384);
tl::fence_proxy_async();
tl::tma_load(A_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
mbarrier[0].expect_transaction(16384);
tl::fence_proxy_async();
tl::tma_load(B_desc, mbarrier[0], (&(((float*)buf_dyn_shmem)[4096])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
}
mbarrier[0].arrive();
} else {
tl::warpgroup_reg_alloc<240>();
mbarrier[0].wait(0);
#pragma unroll
for (int i = 0; i < 8; ++i) {
float4 __1;
float4 v_ = *(float4*)(((float*)buf_dyn_shmem) + ((i * 512) + (((int)threadIdx.x) * 4)));
float4 v__1 = *(float4*)(((float*)buf_dyn_shmem) + (((i * 512) + (((int)threadIdx.x) * 4)) + 4096));
__1.x = (v_.x+v__1.x);
__1.y = (v_.y+v__1.y);
__1.z = (v_.z+v__1.z);
__1.w = (v_.w+v__1.w);
*(float4*)(C_local + (i * 4)) = __1;
}
tl::__sync_thread_partial<3, 128>();
#pragma unroll
for (int i_1 = 0; i_1 < 8; ++i_1) {
*(float4*)(((float*)buf_dyn_shmem) + ((i_1 * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(C_local + (i_1 * 4));
}
tl::fence_proxy_async();
tl::__sync_thread_partial<3, 128>();
if (tl::tl_shuffle_elect<128>()) {
tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 64), (((int)blockIdx.y) * 64));
tl::tma_store_arrive();
tl::tma_store_wait<0>();
}
}
}
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];
extern "C" const char* get_last_error() {
return error_buf;
}
extern "C" int init() {
error_buf[0] = '\0';
cudaError_t result_elem_add_kernel = cudaFuncSetAttribute(elem_add_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 32768);
if (result_elem_add_kernel != cudaSuccess) {
snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 32768, cudaGetErrorString(result_elem_add_kernel));
return -1;
}
return 0;
}
extern "C" int call(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {
CUtensorMap A_desc;
CUtensorMapDataType A_desc_type= (CUtensorMapDataType)7;
cuuint32_t A_desc_tensorRank= 2;
void *A_desc_globalAddress= A;
cuuint64_t A_desc_globalDim[2]= {1024,1024};
cuuint64_t A_desc_globalStride[2]= {4,4096};
cuuint32_t A_desc_boxDim[2]= {64,64};
cuuint32_t A_desc_elementStrides[2]= {1,1};
CUtensorMapInterleave A_desc_interleave= (CUtensorMapInterleave)0;
CUtensorMapSwizzle A_desc_swizzle= (CUtensorMapSwizzle)0;
CUtensorMapL2promotion A_desc_l2Promotion= (CUtensorMapL2promotion)2;
CUtensorMapFloatOOBfill A_desc_oobFill= (CUtensorMapFloatOOBfill)0;
CUresult A_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&A_desc, A_desc_type, A_desc_tensorRank, A_desc_globalAddress, A_desc_globalDim, A_desc_globalStride + 1, A_desc_boxDim, A_desc_elementStrides, A_desc_interleave, A_desc_swizzle, A_desc_l2Promotion, A_desc_oobFill);
if (A_desc_result != CUDA_SUCCESS) {
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor A_desc";
snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
return -1;
}
CUtensorMap B_desc;
CUtensorMapDataType B_desc_type= (CUtensorMapDataType)7;
cuuint32_t B_desc_tensorRank= 2;
void *B_desc_globalAddress= B;
cuuint64_t B_desc_globalDim[2]= {1024,1024};
cuuint64_t B_desc_globalStride[2]= {4,4096};
cuuint32_t B_desc_boxDim[2]= {64,64};
cuuint32_t B_desc_elementStrides[2]= {1,1};
CUtensorMapInterleave B_desc_interleave= (CUtensorMapInterleave)0;
CUtensorMapSwizzle B_desc_swizzle= (CUtensorMapSwizzle)0;
CUtensorMapL2promotion B_desc_l2Promotion= (CUtensorMapL2promotion)2;
CUtensorMapFloatOOBfill B_desc_oobFill= (CUtensorMapFloatOOBfill)0;
CUresult B_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&B_desc, B_desc_type, B_desc_tensorRank, B_desc_globalAddress, B_desc_globalDim, B_desc_globalStride + 1, B_desc_boxDim, B_desc_elementStrides, B_desc_interleave, B_desc_swizzle, B_desc_l2Promotion, B_desc_oobFill);
if (B_desc_result != CUDA_SUCCESS) {
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor B_desc";
snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
return -1;
}
CUtensorMap C_desc;
CUtensorMapDataType C_desc_type= (CUtensorMapDataType)7;
cuuint32_t C_desc_tensorRank= 2;
void *C_desc_globalAddress= C;
cuuint64_t C_desc_globalDim[2]= {1024,1024};
cuuint64_t C_desc_globalStride[2]= {4,4096};
cuuint32_t C_desc_boxDim[2]= {64,64};
cuuint32_t C_desc_elementStrides[2]= {1,1};
CUtensorMapInterleave C_desc_interleave= (CUtensorMapInterleave)0;
CUtensorMapSwizzle C_desc_swizzle= (CUtensorMapSwizzle)0;
CUtensorMapL2promotion C_desc_l2Promotion= (CUtensorMapL2promotion)2;
CUtensorMapFloatOOBfill C_desc_oobFill= (CUtensorMapFloatOOBfill)0;
CUresult C_desc_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&C_desc, C_desc_type, C_desc_tensorRank, C_desc_globalAddress, C_desc_globalDim, C_desc_globalStride + 1, C_desc_boxDim, C_desc_elementStrides, C_desc_interleave, C_desc_swizzle, C_desc_l2Promotion, C_desc_oobFill);
if (C_desc_result != CUDA_SUCCESS) {
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor C_desc";
snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
return -1;
}
elem_add_kernel<<<dim3(16, 16, 1), dim3(256, 1, 1), 32768, stream>>>(A_desc, B_desc, C_desc);
TILELANG_CHECK_LAST_ERROR("elem_add_kernel");
return 0;
}
调用 nvcc 将上面包装好的 .cu 文件编译为共享库（.so），完整命令如下：
['/usr/local/cuda-12.8/bin/nvcc',
'-std=c++17',
'-w',
'-Xcudafe',
'--diag_suppress=177',
'--compiler-options',
'-fPIC',
'-lineinfo',
'--shared',
'/tmp/tmp3f_1m6dk.cu',
'-lcuda',
'-gencode',
'arch=compute_120a,code=sm_120a',
'-I/mnt/d/0_work/share_with_ubuntu/tilelang/3rdparty/cutlass/include',
'-I/mnt/d/0_work/share_with_ubuntu/tilelang/3rdparty/../src',
'-o',
'/tmp/tmp3f_1m6dk.so']
更多推荐

所有评论(0)