
gemm_v0 init parameter drops inner-loop conditions during lowering to Ascend C #170

@blueWatermelonFri

Description


When `T.gemm_v0(..., init=condition)` is used in TileLang-ascend and the init condition combines variables from nested loops (e.g., `(k == 0) and (j == 0)`, where `k` is the outer loop variable and `j` the inner one), the generated Ascend C code keeps only the outer-loop condition `(k == 0)` and silently drops the inner-loop part `(j == 0)`.
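To make the consequence concrete, here is a plain-Python model (not TileLang code) of how one element of `C_L0` evolves across the `k`/`j` loop nest. It assumes that `init=True` makes `gemm_v0` overwrite the accumulator with the new partial product and `init=False` makes it add to it, and it models every partial product as `1`. `simulate_accumulator`, `intended` and `observed` are hypothetical helpers for illustration only.

```python
def simulate_accumulator(k_tiles: int, inner: int, init_fn) -> int:
    """Model one output element of C_L0 across the k/j loop nest.

    Assumption: when init_fn(k, j) is true, gemm_v0 overwrites the
    accumulator with the new partial product; otherwise it adds to it.
    Each partial product is modeled as the constant 1.
    """
    acc = 0
    partial = 1
    for k in range(k_tiles):
        for j in range(inner):
            if init_fn(k, j):
                acc = partial   # init=True: overwrite
            else:
                acc += partial  # init=False: accumulate
    return acc


def intended(k, j):
    # Reset only on the very first iteration of the whole nest.
    return k == 0 and j == 0


def observed(k, j):
    # What the generated code does: the inner-loop condition is dropped.
    return k == 0


if __name__ == "__main__":
    print(simulate_accumulator(3, 8, intended))  # 24
    print(simulate_accumulator(3, 8, observed))  # 17
```

With the dropped condition, every inner iteration of `k == 0` resets the accumulator, so only one of the first eight partial products survives (1 + 2 * 8 = 17 instead of 3 * 8 = 24).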

The TileLang reproducer is as follows:

```python
import argparse

import tilelang
import tilelang.language as T
import torch
import torch_npu  # noqa: F401  (registers the NPU device used by .npu())

tilelang.cache.clear_cache()

parser = argparse.ArgumentParser(description="NPU Kernel Compilation")
parser.add_argument("--m", type=int, default=16, help="Matrix M dimension")
parser.add_argument("--n", type=int, default=8192, help="Matrix N dimension")
parser.add_argument("--k", type=int, default=8191, help="Matrix K dimension")
args = parser.parse_args()

M = args.m
N = args.n
K = args.k


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, K_L1, dtype="float16", accum_dtype="float"):
    m_num = M // block_M
    n_num = N // block_N

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, _):
            bx = cid // n_num
            by = cid % n_num

            A_L1 = T.alloc_L1((block_M, K_L1), dtype) # 16 * 64
            B_L1 = T.alloc_L1((K_L1, block_N), dtype) # 64 * 256

            C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype) # 16 * 256

            with T.Scope("C"):
                loop_k = T.ceildiv(K, K_L1)
                for k in T.serial(loop_k):
                    for j in T.serial(8):
                        T.copy(A[bx * block_M, k * K_L1], A_L1)
                        T.copy(B[k * K_L1, by * block_N], B_L1)

                        T.barrier_all()
                        T.gemm_v0(A_L1, B_L1, C_L0, init=((k == 0) and (j == 0))) # <---------------------

                        T.barrier_all()

                T.copy(C_L0, C[bx * block_M, by * block_N])

    return main


func = matmul(M, N, K, 16, 256, 64)

print(func.get_kernel_source())


torch.manual_seed(0)

a = torch.randn(M, K).half().npu()
b = torch.randn(K, N).half().npu()
c = torch.empty(M, N).half().npu()
print("init successful!")

c = func(a, b)

ref_c = a @ b

torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")
```

The generated Ascend C code is:

```cpp
tl::ascend::gemm_v0<half, float, 16, 256, 64, false, false>(A_L1[0], B_L1[0], C_L0[0], ascend_l0a, ascend_l0b, (k == 0));
```

whereas I expected:

```cpp
tl::ascend::gemm_v0<half, float, 16, 256, 64, false, false>(A_L1[0], B_L1[0], C_L0[0], ascend_l0a, ascend_l0b, (k == 0 && j == 0));
```
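To check this programmatically against the string returned by `func.get_kernel_source()`, a small string-based helper can extract the init argument and report which loop variables it never mentions. This is a hypothetical sketch (`gemm_v0_init_arg` and `missing_loop_vars` are not TileLang APIs); it assumes the `gemm_v0` call sits on a single line and that the init condition is its last argument, as in the generated code above.

```python
import re


def gemm_v0_init_arg(kernel_source: str) -> str:
    """Return the last argument (the init condition) of the first gemm_v0 call.

    Assumes the call fits on one line and ends with `);`.
    """
    match = re.search(r"gemm_v0<[^>]*>\((.*)\);", kernel_source)
    if match is None:
        raise ValueError("no gemm_v0 call found")
    # Split off the final argument; the condition itself contains no commas here.
    return match.group(1).rsplit(",", 1)[1].strip()


def missing_loop_vars(condition: str, loop_vars) -> list:
    """Loop variables that never appear as whole identifiers in `condition`."""
    return [v for v in loop_vars
            if not re.search(rf"\b{re.escape(v)}\b", condition)]


if __name__ == "__main__":
    got = ("tl::ascend::gemm_v0<half, float, 16, 256, 64, false, false>"
           "(A_L1[0], B_L1[0], C_L0[0], ascend_l0a, ascend_l0b, (k == 0));")
    want = ("tl::ascend::gemm_v0<half, float, 16, 256, 64, false, false>"
            "(A_L1[0], B_L1[0], C_L0[0], ascend_l0a, ascend_l0b, (k == 0 && j == 0));")
    print(missing_loop_vars(gemm_v0_init_arg(got), ["k", "j"]))   # ['j']
    print(missing_loop_vars(gemm_v0_init_arg(want), ["k", "j"]))  # []
```

In the reproducer, one would pass `func.get_kernel_source()` instead of the hard-coded strings.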

Is this behavior intentional?

If so, could the documentation state that `init` only supports conditions on the outermost loop variable?
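One possible explanation, which I have not verified against the TileLang front end: if the `init` expression is evaluated as ordinary Python instead of being translated into a TIR logical-and node, then `a and b` returns `a` whenever `bool(a)` is false. TVM's Python wrapper for `==` appears to define `bool()` as an identity check of its two operands, which would make `k == 0` falsy for a loop variable and yield exactly the `(k == 0)` seen in the generated code. The sketch below imitates that behavior with hypothetical `FakeVar`/`FakeEq` classes; they are stand-ins, not real TVM objects.

```python
class FakeVar:
    """Hypothetical stand-in for a symbolic loop variable (not a TVM class)."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Build a symbolic equality node instead of returning a plain bool.
        return FakeEq(self, other)


class FakeEq:
    """Hypothetical symbolic `lhs == rhs` node.

    Assumption: like the TVM wrapper described above, bool() only checks
    whether both sides are the same object, so `k == 0` is falsy.
    """

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def __bool__(self):
        return self.lhs is self.rhs

    def __repr__(self):
        return f"({self.lhs.name} == {self.rhs})"


k, j = FakeVar("k"), FakeVar("j")
cond = (k == 0) and (j == 0)
print(cond)  # (k == 0)  <- Python's `and` short-circuited and returned its first operand
```

If this is the cause, building the conjunction with a TIR operator rather than Python's `and` (for example `tvm.tir.all(k == 0, j == 0)`) might avoid it, but I have not tested that with `T.gemm_v0`.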

Target: Ascend 910B
Backend: Ascend C
