
T.Pipelined is not compatible with the TL_ASCEND_AUTO_CV_COMBINE switch #162


Description


examples/pipeline/matmul_add_pipeline.py fails to compile when the TL_ASCEND_AUTO_CV_COMBINE switch is additionally turned on.
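The full reproduction script is below; the relevant change is enabling the switch through pass_configs on the @tilelang.jit decorator (the with T.Scope("C") / T.Scope("V") blocks are also removed, as marked in the inline comments):

@tilelang.jit(
    out_idx=[-2],
    pass_configs={
        tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True,  # the switch in question
    }
)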

The full modified script:
import argparse

import tilelang
import tilelang.language as T
import torch

tilelang.cache.clear_cache()

parser = argparse.ArgumentParser(description="NPU Kernel Compilation")
parser.add_argument("--m", type=int, default=1024, help="Matrix M dimension")
parser.add_argument("--n", type=int, default=1024, help="Matrix N dimension")
parser.add_argument("--k", type=int, default=1024, help="Matrix K dimension")
args = parser.parse_args()

M = args.m
N = args.n
K = args.k


@tilelang.jit(
    out_idx=[-2],
    pass_configs={
        # tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_SYNC: True,
        tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True,  # turn on the switch
    }
)
def matmul_add(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    m_num = M // block_M
    n_num = N // block_N

    VEC_NUM = 2
    vec_proc = 4

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
            D: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
            bx = cid // n_num
            by = cid % n_num
            A_L1 = T.alloc_L1((block_M, block_K), dtype)
            B_L1 = T.alloc_L1((block_K, block_N), dtype)

            C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype)

            c_ub = T.alloc_ub((block_M // VEC_NUM, block_N // vec_proc), dtype)
            d_ub = T.alloc_ub((block_M // VEC_NUM, block_N // vec_proc), dtype)
            e_ub = T.alloc_ub((block_M // VEC_NUM, block_N // vec_proc), dtype)

            # with T.Scope("C"):    # remove with T.Scope
            loop_k = T.ceildiv(K, block_K)
            for k in T.Pipelined(loop_k, num_stages=3):
                T.copy(A[bx * block_M, k * block_K], A_L1)
                T.copy(B[k * block_K, by * block_N], B_L1)

                T.barrier_all()
                if k == 0:
                    T.gemm_v0(A_L1, B_L1, C_L0, init=True)
                else:
                    T.gemm_v0(A_L1, B_L1, C_L0)

                T.barrier_all()

            T.copy(C_L0, C[bx * block_M, by * block_N])

            T.set_cross_flag("FIX", 0)

            # with T.Scope("V"):    # remove with T.Scope
            T.wait_cross_flag(0)

            for i in T.Pipelined(vec_proc, num_stages=2):
                T.copy(C[bx * block_M + vid * block_M // VEC_NUM, by * block_N + i * block_N // vec_proc], c_ub)
                T.copy(D[bx * block_M + vid * block_M // VEC_NUM, by * block_N + i * block_N // vec_proc], d_ub)

                T.barrier_all()
                T.tile.add(e_ub, c_ub, d_ub)
                T.barrier_all()

                T.copy(e_ub, C[bx * block_M + vid * block_M // VEC_NUM, by * block_N + i * block_N // vec_proc])
                T.barrier_all()

    return main


func = matmul_add(M, N, K, 128, 256, 64)

torch.manual_seed(0)

a = torch.randn(M, K).half().npu()
b = torch.randn(K, N).half().npu()
d = torch.randn(M, N).half().npu()
print("init successful!")

c = func(a, b, d)

ref_c = a @ b + d

torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")

Here's the error message:

> python matmul_add_pipeline.py 
Traceback (most recent call last):
  File " .../tilelang-ascend/examples/pipeline/matmul_add_pipeline_cv.py", line 88, in <module>
    func = matmul_add(M, N, K, 128, 256, 64)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/jit/__init__.py", line 204, in wrapper
    kernel_result = compile(
                    ^^^^^^^^
  File " .../tilelang-ascend/tilelang/jit/__init__.py", line 76, in compile
    return cached(
           ^^^^^^^
  File " .../tilelang-ascend/tilelang/cache/__init__.py", line 31, in cached
    return _kernel_cache_instance.cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/cache/kernel_cache.py", line 174, in cached
    kernel = JITKernel(
             ^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/jit/kernel.py", line 112, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx, workspace_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/jit/kernel.py", line 216, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/engine/lower.py", line 223, in lower
    mod = OptimizeForTarget(mod, target)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/engine/phase.py", line 81, in OptimizeForTarget
    mod = tilelang.transform.PipelinePlanning()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " .../tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File " .../tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/_ffi/base.py", line 465, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  30: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::allocator<char>, tvm::runtime::TVMArgs const&)
  29: tvm::transform::Pass::operator()(tvm::IRModule) const
  28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  27: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: _ZN3tvm7runtime13PackedF
  25: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tl::PipelinePlanning()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tl::PipelinePlanning()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  24: tvm::tl::PipelinePlanner::Substitute(tvm::tir::PrimFunc const&)
  23: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  22: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  21: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  20: tvm::tl::PipelinePlanner::VisitStmt_(tvm::tir::BlockNode const*)
  19: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  18: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  17: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  16: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  15: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  14: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  13: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  12: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  11: tvm::tl::PipelinePlanner::VisitStmt_(tvm::tir::BlockNode const*)
  10: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  9: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  8: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  7: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  6: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  4: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  3: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  2: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  1: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
  0: tvm::tl::PipelinePlanner::VisitStmt_(tvm::tir::ForNode const*)
  File " .../tilelang-ascend/src/transform/pipeline_planning.cc", line 404
Pipeline_Planning: Can't handle the body of the loop because it is not a SeqStmt or IfThenElse
[ERROR] 2025-12-23-21:56:31 (PID:454121, Device:-1, RankID:-1) ERR99999 UNKNOWN applicaiton exception
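For comparison (an assumption based on the description above, not re-verified here), the same script compiles when the switch is left at its default, i.e. with a decorator along these lines:

# minimal sketch of the stock configuration, assuming the unmodified example
# simply omits the TL_ASCEND_AUTO_CV_COMBINE pass_configs entry
@tilelang.jit(out_idx=[-2])
def matmul_add(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    ...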
