-
Notifications
You must be signed in to change notification settings - Fork 52
Open
Description
examples/pipeline/matmul_add_pipeline.py will fail when the TL_ASCEND_AUTO_CV_COMBINE switch is additionally turned on.
The modified code:
import argparse
import tilelang
import tilelang.language as T
import torch

# Drop any previously compiled kernels so this repro always recompiles.
tilelang.cache.clear_cache()

# Command-line problem size; defaults to a 1024x1024x1024 GEMM.
_cli = argparse.ArgumentParser(description="NPU Kernel Compilation")
for _dim in ("m", "n", "k"):
    _cli.add_argument(f"--{_dim}", type=int, default=1024,
                      help=f"Matrix {_dim.upper()} dimension")
args = _cli.parse_args()

M, N, K = args.m, args.n, args.k
@tilelang.jit(
    out_idx=[-2],
    pass_configs={
        # tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_SYNC: True,
        tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True,  # turn on the switch
    }
)
def matmul_add(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a fused NPU kernel computing C = A @ B followed by C = C + D.

    The cube (matmul) part accumulates A @ B tile-by-tile over K into L0C and
    writes the result to C; the vector part then loads C and D slices, adds
    them, and writes the sum back into C.  `out_idx=[-2]` marks the
    second-to-last tensor argument (C) as the kernel output.

    Args:
        M, N, K: global GEMM dimensions.
        block_M, block_N, block_K: tile sizes; M/N/K are assumed to be
            divisible by them (integer division below) — TODO confirm.
        dtype: element type of A/B/C/D.
        accum_dtype: accumulator type for the L0C matmul result.

    Returns:
        The compiled `T.prim_func` implementing the fused kernel.
    """
    # Number of output tiles along M and N; one kernel block per tile.
    m_num = M // block_M
    n_num = N // block_N
    # Each vector lane (vid) handles a block_M // VEC_NUM row slice of the tile.
    VEC_NUM = 2
    # The vector stage sweeps the tile's columns in vec_proc chunks.
    vec_proc = 4

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
            D: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
            # Decode the linear block id into (row, column) tile coordinates.
            bx = cid // n_num
            by = cid % n_num
            # On-chip staging buffers: A/B tiles in L1, accumulator in L0C.
            A_L1 = T.alloc_L1((block_M, block_K), dtype)
            B_L1 = T.alloc_L1((block_K, block_N), dtype)
            C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype)
            # UB buffers sized for one (row-slice, column-chunk) of the tile.
            c_ub = T.alloc_ub((block_M // VEC_NUM, block_N // vec_proc), dtype)
            d_ub = T.alloc_ub((block_M // VEC_NUM, block_N // vec_proc), dtype)
            e_ub = T.alloc_ub((block_M // VEC_NUM, block_N // vec_proc), dtype)
            # with T.Scope("C"): # remove with T.Scope
            # Cube side: pipelined reduction over K into the L0C accumulator.
            loop_k = T.ceildiv(K, block_K)
            for k in T.Pipelined(loop_k, num_stages=3):
                T.copy(A[bx * block_M, k * block_K], A_L1)
                T.copy(B[k * block_K, by * block_N], B_L1)
                T.barrier_all()
                if k == 0:
                    # First K-step initializes the accumulator instead of adding.
                    T.gemm_v0(A_L1, B_L1, C_L0, init=True)
                else:
                    T.gemm_v0(A_L1, B_L1, C_L0)
                T.barrier_all()
            # Spill the accumulated tile to global C, then signal the vector
            # side via the "FIX" cross flag that C is ready to be read.
            T.copy(C_L0, C[bx * block_M, by * block_N])
            T.set_cross_flag("FIX", 0)
            # with T.Scope("V"): # remove with T.Scope
            # Vector side: wait for the matmul result, then fuse in D.
            T.wait_cross_flag(0)
            for i in T.Pipelined(vec_proc, num_stages=2):
                T.copy(C[bx * block_M + vid * block_M // VEC_NUM, by * block_N + i * block_N // vec_proc], c_ub)
                T.copy(D[bx * block_M + vid * block_M // VEC_NUM, by * block_N + i * block_N // vec_proc], d_ub)
                T.barrier_all()
                # Elementwise e_ub = c_ub + d_ub, written back into C below.
                T.tile.add(e_ub, c_ub, d_ub)
                T.barrier_all()
                T.copy(e_ub, C[bx * block_M + vid * block_M // VEC_NUM, by * block_N + i * block_N // vec_proc])
                T.barrier_all()
    return main
# Compile the kernel: 128x256 output tiles, K reduced in chunks of 64.
func = matmul_add(M, N, K, 128, 256, 64)

# Deterministic fp16 inputs on the NPU (kernel dtype defaults to float16).
torch.manual_seed(0)
a = torch.randn(M, K).to(torch.float16).npu()
b = torch.randn(K, N).to(torch.float16).npu()
d = torch.randn(M, N).to(torch.float16).npu()
print("init successful!")

# Run the fused kernel and compare against the eager-mode reference.
c = func(a, b, d)
ref_c = a @ b + d
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")Here's the error message:
> python matmul_add_pipeline.py
Traceback (most recent call last):
File " .../tilelang-ascend/examples/pipeline/matmul_add_pipeline_cv.py", line 88, in <module>
func = matmul_add(M, N, K, 128, 256, 64)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " .../tilelang-ascend/tilelang/jit/__init__.py", line 204, in wrapper
kernel_result = compile(
^^^^^^^^
File " .../tilelang-ascend/tilelang/jit/__init__.py", line 76, in compile
return cached(
^^^^^^^
File " .../tilelang-ascend/tilelang/cache/__init__.py", line 31, in cached
return _kernel_cache_instance.cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " .../tilelang-ascend/tilelang/cache/kernel_cache.py", line 174, in cached
kernel = JITKernel(
^^^^^^^^^^
File " .../tilelang-ascend/tilelang/jit/kernel.py", line 112, in __init__
adapter = self._compile_and_create_adapter(func, out_idx, workspace_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " .../tilelang-ascend/tilelang/jit/kernel.py", line 216, in _compile_and_create_adapter
artifact = tilelang.lower(
^^^^^^^^^^^^^^^
File " .../tilelang-ascend/tilelang/engine/lower.py", line 223, in lower
mod = OptimizeForTarget(mod, target)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " .../tilelang-ascend/tilelang/engine/phase.py", line 81, in OptimizeForTarget
mod = tilelang.transform.PipelinePlanning()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " .../tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " .../tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File " .../tilelang-ascend/tilelang/../3rdparty/tvm/python/tvm/_ffi/base.py", line 465, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
30: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::allocator<char>, tvm::runtime::TVMArgs const&)
29: tvm::transform::Pass::operator()(tvm::IRModule) const
28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
27: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
26: _ZN3tvm7runtime13PackedF
25: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tl::PipelinePlanning()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tl::PipelinePlanning()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
24: tvm::tl::PipelinePlanner::Substitute(tvm::tir::PrimFunc const&)
23: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
22: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
21: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
20: tvm::tl::PipelinePlanner::VisitStmt_(tvm::tir::BlockNode const*)
19: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
18: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
17: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
16: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
15: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
14: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
13: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
12: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
11: tvm::tl::PipelinePlanner::VisitStmt_(tvm::tir::BlockNode const*)
10: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
9: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
8: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
7: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
6: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
4: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
3: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
2: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
1: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTable
0: tvm::tl::PipelinePlanner::VisitStmt_(tvm::tir::ForNode const*)
File " .../tilelang-ascend/src/transform/pipeline_planning.cc", line 404
Pipeline_Planning: Can't handle the body of the loop because it is not a SeqStmt or IfThenElse
[ERROR] 2025-12-23-21:56:31 (PID:454121, Device:-1, RankID:-1) ERR99999 UNKNOWN applicaiton exceptionMetadata
Metadata
Assignees
Labels
No labels