Skip to content

Commit 7275571

Browse files
committed
fix test
1 parent 55b9b05 commit 7275571

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,19 @@ def test_lower_hopper_intrin_barrier():
2828
def before():
2929
with T.Kernel(8):
3030
_ = T.launch_thread("threadIdx.x", 128)
31-
T.create_list_of_mbarrier(128, 128, 128, 128)
31+
T.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), 128, 128, 128, 128)
3232

3333
@T.prim_func
3434
def after():
3535
with T.Kernel(8):
3636
v_1 = T.launch_thread("threadIdx.x", 128)
37-
T.evaluate(tir.Call("handle", "tir.create_barriers", [4]))
38-
with T.If(v_1 == 0), T.Then():
39-
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128]))
40-
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128]))
41-
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128]))
42-
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128]))
37+
mbarrier = T.alloc_barrier([128, 128, 128, 128])
38+
with T.If(tir.Call("bool", tir.op.Op.get("tl.tl_shuffle_elect"), [0])), T.Then():
39+
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), 0), 128]))
40+
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), 1), 128]))
41+
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), 2), 128]))
42+
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), 3), 128]))
43+
T.evaluate(tir.Call("handle", tir.op.Op.get("tl.ptx_fence_barrier_init"), []))
4344
T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))
4445

4546
_check(before, after)

testing/python/transform/test_tilelang_transform_warp_specialized.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,45 +76,45 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
7676
A_shared = T.decl_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
7777
B_shared = T.decl_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
7878
C_local = T.decl_buffer((32,), scope="local")
79-
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
79+
T.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), 128, 128, 128, 128, 128, 128)
8080
T.attr([128, 128], "kWarpSpecializationScope", 0)
8181
if v >= 128:
8282
T.set_max_nreg(24, 0)
8383
for k in range(16):
84-
T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
84+
T.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
8585
if v - 128 == 0:
86-
T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
86+
T.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3), 4096)
8787
if v - 128 == 0:
8888
T.tma_load(
8989
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
90-
T.get_mbarrier(k % 3),
90+
T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3),
9191
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
9292
k * 32,
9393
by * 64,
9494
)
9595
if v - 128 == 0:
96-
T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
96+
T.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3), 4096)
9797
if v - 128 == 0:
9898
T.tma_load(
9999
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
100-
T.get_mbarrier(k % 3),
100+
T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3),
101101
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2),
102102
bx * 64,
103103
k * 32,
104104
)
105-
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
105+
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3)]))
106106
else:
107107
T.set_max_nreg(240, 1)
108108
for k in range(16):
109-
T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
109+
T.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3), k // 3 % 2)
110110
T.call_extern(
111111
"handle",
112112
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
113113
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
114114
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
115115
T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
116116
)
117-
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
117+
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), k % 3 + 3)]))
118118

119119
_check(before, after)
120120

0 commit comments

Comments
 (0)