Description
When using `T.gemm_v0(..., init=condition)` in TileLang-ascend, if the `init` condition combines variables from nested loops (e.g. `(k == 0) and (j == 0)`, where `k` is the outer loop variable and `j` the inner one), the generated Ascend C code keeps only the outer-loop condition `(k == 0)` and silently drops the inner-loop part `(j == 0)`.
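One possible explanation, not confirmed here: if the `init` expression is evaluated with Python's own `and` rather than turned into a symbolic conjunction, the right operand can vanish before code generation. Python's `a and b` returns `a` unchanged whenever `bool(a)` is false, and a symbolic comparison object may define its truthiness so that `k == 0` is falsy. The self-contained sketch below uses hypothetical `SymVar`, `SymEq`, and `SymAnd` classes (not TileLang or TVM types) to reproduce that effect.

```python
class SymVar:
    """Hypothetical stand-in for a symbolic loop variable (not a TileLang/TVM class)."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Build a symbolic comparison node instead of returning a Python bool.
        return SymEq(self, other)

    __hash__ = object.__hash__


class SymEq:
    """Hypothetical symbolic `lhs == rhs` node."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def __bool__(self):
        # Truthiness is an identity check, so `k == 0` is falsy here.
        return self.lhs is self.rhs

    def __and__(self, other):
        # An overloaded operator can build a node that keeps both operands.
        return SymAnd(self, other)

    def __repr__(self):
        return f"({self.lhs.name} == {self.rhs})"


class SymAnd:
    """Hypothetical symbolic conjunction node."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __repr__(self):
        return f"({self.a!r} && {self.b!r})"


k, j = SymVar("k"), SymVar("j")

# Python's `and` calls bool() on the left operand; because it is falsy,
# `and` returns it unchanged and never evaluates `j == 0`.
dropped = (k == 0) and (j == 0)
print(dropped)  # (k == 0)

# A node-building operator keeps both sides of the condition.
kept = (k == 0) & (j == 0)
print(kept)  # ((k == 0) && (j == 0))
```

Whether TileLang's frontend actually takes this path is an assumption; if it does, writing the condition with an explicit TIR conjunction instead of Python's `and` might avoid the drop, but that has not been verified against TileLang-ascend.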
The TileLang code is as follows:

```python
import argparse

import tilelang
import tilelang.language as T
import torch

tilelang.cache.clear_cache()

parser = argparse.ArgumentParser(description="NPU Kernel Compilation")
parser.add_argument("--m", type=int, default=16, help="Matrix M dimension")
parser.add_argument("--n", type=int, default=8192, help="Matrix N dimension")
parser.add_argument("--k", type=int, default=8191, help="Matrix K dimension")
args = parser.parse_args()

M = args.m
N = args.n
K = args.k


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, K_L1, dtype="float16", accum_dtype="float"):
    m_num = M // block_M
    n_num = N // block_N

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, _):
            bx = cid // n_num
            by = cid % n_num
            A_L1 = T.alloc_L1((block_M, K_L1), dtype)  # 16 * 64
            B_L1 = T.alloc_L1((K_L1, block_N), dtype)  # 64 * 256
            C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype)  # 16 * 256
            with T.Scope("C"):
                loop_k = T.ceildiv(K, K_L1)
                for k in T.serial(loop_k):
                    for j in T.serial(8):
                        T.copy(A[bx * block_M, k * K_L1], A_L1)
                        T.copy(B[k * K_L1, by * block_N], B_L1)
                        T.barrier_all()
                        T.gemm_v0(A_L1, B_L1, C_L0, init=((k == 0) and (j == 0)))  # <---------------------
                        T.barrier_all()
                T.copy(C_L0, C[bx * block_M, by * block_N])

    return main


func = matmul(M, N, K, 16, 256, 64)
print(func.get_kernel_source())

torch.manual_seed(0)
a = torch.randn(M, K).half().npu()
b = torch.randn(K, N).half().npu()
c = torch.empty(M, N).half().npu()
print("init successful!")

c = func(a, b)
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")
```

The obtained Ascend C code is as follows:
```cpp
tl::ascend::gemm_v0<half, float, 16, 256, 64, false, false>(A_L1[0], B_L1[0], C_L0[0], ascend_l0a, ascend_l0b, (k == 0));
```

What I want is:

```cpp
tl::ascend::gemm_v0<half, float, 16, 256, 64, false, false>(A_L1[0], B_L1[0], C_L0[0], ascend_l0a, ascend_l0b, (k == 0 && j == 0));
```

Is this behavior intentional? If so, could the documentation clarify that `init` only supports conditions based on the outermost loop?
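To make the practical effect of the dropped `(j == 0)` concrete, the pure-Python sketch below models `gemm_v0` as "zero the accumulator when `init` is true, then add the tile product". That reset-then-accumulate behavior is an assumption, and each K tile's product is reduced to a single hypothetical number. With only `k == 0`, the accumulator is re-zeroed on every `j` iteration of the first K tile, so all but the last contribution of that tile is discarded.

```python
def simulate(tile_products, inner_iters, init):
    """Hypothetical model of gemm_v0: reset when `init` is true, then accumulate.

    Each entry of `tile_products` stands in for one K tile's A_L1 @ B_L1
    contribution, reduced to a scalar for illustration.
    """
    acc = 0
    for k, p in enumerate(tile_products):
        for j in range(inner_iters):
            if init(k, j):
                acc = 0  # init=True: discard previous partial sums
            acc += p     # accumulate this iteration's tile product
    return acc


products = [1, 2, 3]  # hypothetical per-K-tile contributions
inner = 8             # matches T.serial(8) in the repro

intended = simulate(products, inner, lambda k, j: k == 0 and j == 0)
generated = simulate(products, inner, lambda k, j: k == 0)
print(intended)   # 48 = 8 * (1 + 2 + 3)
print(generated)  # 41 = 1 + 8 * (2 + 3)
```

The two results differ only in how the first K tile is counted, which is exactly the part controlled by the inner-loop term of the condition.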
Target: Ascend 910B
Backend: Ascend C