@@ -76,45 +76,45 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
7676 A_shared = T .decl_buffer ((3 , 1 , 8 , 256 ), T .float16 , scope = "shared.dyn" )
7777 B_shared = T .decl_buffer ((3 , 1 , 4 , 512 ), T .float16 , scope = "shared.dyn" )
7878 C_local = T .decl_buffer ((32 ,), scope = "local" )
79- T .create_list_of_mbarrier ( 128 , 128 , 128 , 128 , 128 , 128 )
79+ T .call_intrin ( "handle" , tir . op . Op . get ( "tl.create_list_of_mbarrier" ), 128 , 128 , 128 , 128 , 128 , 128 )
8080 T .attr ([128 , 128 ], "kWarpSpecializationScope" , 0 )
8181 if v >= 128 :
8282 T .set_max_nreg (24 , 0 )
8383 for k in range (16 ):
84- T .mbarrier_wait_parity ( T . get_mbarrier ( k % 3 + 3 ), T .bitwise_xor (k // 3 % 2 , 1 ))
84+ T .call_intrin ( "handle" , tir . op . Op . get ( "tl.mbarrier_wait_parity" ), T . call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 + 3 ), T .bitwise_xor (k // 3 % 2 , 1 ))
8585 if v - 128 == 0 :
86- T .mbarrier_expect_tx ( T . get_mbarrier ( k % 3 ), 4096 )
86+ T .call_intrin ( "handle" , tir . op . Op . get ( "tl.mbarrier_expect_tx" ), T . call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 ), 4096 )
8787 if v - 128 == 0 :
8888 T .tma_load (
8989 T .create_tma_descriptor (6 , 2 , A .data , 512 , 512 , 2 , 1024 , 32 , 64 , 1 , 1 , 0 , 2 , 2 , 0 ),
90- T .get_mbarrier ( k % 3 ),
90+ T .call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 ),
9191 T .tvm_access_ptr (T .type_annotation (T .float16 ), A_shared .data , k % 3 * 2048 , 2048 , 2 ),
9292 k * 32 ,
9393 by * 64 ,
9494 )
9595 if v - 128 == 0 :
96- T .mbarrier_expect_tx ( T . get_mbarrier ( k % 3 ), 4096 )
96+ T .call_intrin ( "handle" , tir . op . Op . get ( "tl.mbarrier_expect_tx" ), T . call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 ), 4096 )
9797 if v - 128 == 0 :
9898 T .tma_load (
9999 T .create_tma_descriptor (6 , 2 , B .data , 512 , 512 , 2 , 1024 , 64 , 32 , 1 , 1 , 0 , 3 , 2 , 0 ),
100- T .get_mbarrier ( k % 3 ),
100+ T .call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 ),
101101 T .tvm_access_ptr (T .type_annotation (T .float16 ), B_shared .data , k % 3 * 2048 , 2048 , 2 ),
102102 bx * 64 ,
103103 k * 32 ,
104104 )
105- T .evaluate (tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .get_mbarrier ( k % 3 )]))
105+ T .evaluate (tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 )]))
106106 else :
107107 T .set_max_nreg (240 , 1 )
108108 for k in range (16 ):
109- T .mbarrier_wait_parity ( T . get_mbarrier ( k % 3 ), k // 3 % 2 )
109+ T .call_intrin ( "handle" , tir . op . Op . get ( "tl.mbarrier_wait_parity" ), T . call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 ), k // 3 % 2 )
110110 T .call_extern (
111111 "handle" ,
112112 "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>" ,
113113 T .tvm_access_ptr (T .type_annotation (T .float16 ), A_shared .data , k % 3 * 2048 , 2048 , 1 ),
114114 T .tvm_access_ptr (T .type_annotation (T .float16 ), B_shared .data , k % 3 * 2048 , 2048 , 1 ),
115115 T .tvm_access_ptr (T .type_annotation (T .float32 ), C_local .data , 0 , 32 , 3 ),
116116 )
117- T .evaluate (tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .get_mbarrier ( k % 3 + 3 )]))
117+ T .evaluate (tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .call_intrin ( "handle" , tir . op . Op . get ( "tl.get_mbarrier" ), k % 3 + 3 )]))
118118
119119 _check (before , after )
120120
0 commit comments