Changes from 5 commits
25 commits
2a67387
[Feat] PDL Support
w169q169 Dec 21, 2025
4898b56
remove unused comments/codes
w169q169 Dec 22, 2025
f50aa99
fix mismatch funcname
w169q169 Dec 22, 2025
b846450
Resolve duplicate config issue for AI review
w169q169 Dec 22, 2025
cb052f4
fix typo error && clang-tidy error
w169q169 Dec 22, 2025
8ed4124
chore: retrigger CI
w169q169 Dec 23, 2025
466c94a
change file comment
w169q169 Dec 23, 2025
05b8dc0
rename test_tilelang_jit_ctypes.py to test_tilelang_jit_cython.py
w169q169 Dec 23, 2025
62ed58f
remove ctype backend
w169q169 Dec 24, 2025
46fe83f
Move pdl attributes to header and remove redundant VisitStmt_
w169q169 Dec 24, 2025
898135e
chore: trigger CI
w169q169 Dec 24, 2025
1c75231
Merge branch 'main' of https://github.com/tile-ai/tilelang
silentCoder-dev Dec 25, 2025
7ed6748
modify the cuda kernel with pdl due to nvcc's bug
silentCoder-dev Dec 25, 2025
8978b11
throw an error when invoking __ldg with pdl due to nvcc's bug
silentCoder-dev Dec 25, 2025
af5f26d
Throw an error when pdl is not supported
silentCoder-dev Dec 25, 2025
a84034b
fix checking about pdl
silentCoder-dev Dec 25, 2025
c123fd7
add comments about MarkCudaSyncCalls
silentCoder-dev Dec 25, 2025
e0f058f
remove pdl support with compute_capability<90
silentCoder-dev Dec 26, 2025
017d314
Merge branch 'main' of https://github.com/tile-ai/tilelang
silentCoder-dev Dec 26, 2025
877962b
ruff check
silentCoder-dev Dec 26, 2025
8ad39ab
fix nvrtc about pdl
w169q169 Dec 26, 2025
22455f2
Merge branch 'w169q169/main' into HEAD
w169q169 Dec 26, 2025
2f83ea7
update tvm & extend pdl support for tvm_ffi & add test for tvm_ffi
silentCoder-dev Dec 26, 2025
24117e6
ensure that kUseDynamicSharedMemoryTag is the last tag in launch_para…
silentCoder-dev Dec 29, 2025
c1cdf7c
Merge remote-tracking branch 'upstream'
silentCoder-dev Dec 29, 2025
4 changes: 4 additions & 0 deletions docs/programming_guides/instructions.md
@@ -138,6 +138,10 @@ Annotation helpers
- `T.annotate_safe_value(var, ...)`: Safety/const hints.
- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint.

Synchronization helpers
- `T.pdl_trigger()`: Signal programmatic launch completion for the current kernel so dependent kernels may start launching (lowers to `cudaTriggerProgrammaticLaunchCompletion`).
- `T.pdl_sync()`: Wait until the kernels this grid depends on have completed (lowers to `cudaGridDependencySynchronize`).
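
For context, a minimal sketch of how the two helpers pair up across dependent kernels, mirroring the tests added in this PR (the factory name `chained_kernels` is illustrative); the resulting `prim_func` can be compiled the same way the tests below do, e.g. `tilelang.compile(program, target="cuda -arch=sm_90")`:

```python
import tilelang.language as T


def chained_kernels(N, block_size=256, dtype=T.float32):

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
            C: T.Tensor((N,), dtype),
    ):
        # Producer kernel: fill B, then signal that dependent kernels may start launching.
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
            for i in T.Parallel(block_size):
                idx = bx * block_size + i
                if idx < N:
                    B[idx] = A[idx] + 1.0
            T.pdl_trigger()

        # Consumer kernel: wait for the producer to finish before reading B.
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
            T.pdl_sync()
            for i in T.Parallel(block_size):
                idx = bx2 * block_size + i
                if idx < N:
                    C[idx] = B[idx] * 2.0

    return main
```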

Atomics
- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`.
- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`.
154 changes: 154 additions & 0 deletions src/transform/lower_pdl.cc
@@ -0,0 +1,154 @@
/*!
* \file lower_pdl.cc
* \brief Mark device PrimFuncs with PDL attributes when CUDA sync calls are present, or eliminate those calls when PDL is unsupported
*/

#include "../op/builtin.h"
#include "../target/utils.h"
#include "tvm/ir/type.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tl {

namespace attr {
// PrimFunc attributes recording which CUDA PDL sync calls a kernel contains
constexpr const char *kHasTriggerLaunch = "has_cuda_pdl_trigger";
constexpr const char *kHasGridSync = "has_cuda_pdl_sync";
} // namespace attr

using namespace tir;

class MarkCudaSyncCalls : public StmtExprMutator {
public:
  static PrimFunc Substitute(PrimFunc f) {
    MarkCudaSyncCalls mutator;
    PrimFunc new_f = f;
    new_f.CopyOnWrite()->body = mutator.VisitStmt(f->body);

    if (mutator.has_trigger_launch_) {
      new_f = WithAttr(std::move(new_f), attr::kHasTriggerLaunch, 1);
    }
    if (mutator.has_grid_sync_) {
      new_f = WithAttr(std::move(new_f), attr::kHasGridSync, 1);
    }
    return new_f;
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const auto *call = op->value.as<tir::CallNode>()) {
      CheckCall(call);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const tir::CallNode *op) final {
    CheckCall(op);
    return StmtExprMutator::VisitExpr_(op);
  }

private:
  void CheckCall(const tir::CallNode *call) {
    if (!call)
      return;
    if (call->op.same_as(builtin::call_extern())) {
      if (!call->args.empty()) {
        if (const auto *str_node =
                call->args[0].as<tvm::tir::StringImmNode>()) {
          std::string func_name = str_node->value;
          if (func_name == "cudaTriggerProgrammaticLaunchCompletion") {
            has_trigger_launch_ = true;
          } else if (func_name == "cudaGridDependencySynchronize") {
            has_grid_sync_ = true;
          }
        }
      }
    }
  }

private:
  bool has_trigger_launch_ = false;
  bool has_grid_sync_ = false;

  MarkCudaSyncCalls() = default;
};

class EliminateCudaSyncCalls : public StmtExprMutator {
public:
  static PrimFunc Substitute(PrimFunc f) {
    EliminateCudaSyncCalls mutator;
    PrimFunc new_f = f;
    new_f.CopyOnWrite()->body = mutator.VisitStmt(f->body);

    return new_f;
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const auto *call = op->value.as<tir::CallNode>()) {
      if (CheckCall(call)) {
        return Evaluate(make_zero(call->dtype));
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const tir::CallNode *op) final {
    if (CheckCall(op)) {
      return make_zero(op->dtype);
    }

    return StmtExprMutator::VisitExpr_(op);
  }

private:
  bool CheckCall(const tir::CallNode *call) {
    if (!call)
      return false;

    if (call->op.same_as(builtin::call_extern())) {
      if (!call->args.empty()) {
        if (const auto *str_node =
                call->args[0].as<tvm::tir::StringImmNode>()) {
          std::string func_name = str_node->value;
          if (func_name == "cudaTriggerProgrammaticLaunchCompletion") {
            return true;
          } else if (func_name == "cudaGridDependencySynchronize") {
            return true;
          }
        }
      }
    }

    return false;
  }

private:
  EliminateCudaSyncCalls() = default;
};

using namespace tir::transform;

tvm::transform::Pass MarkCudaSyncCallsPass(bool have_pdl) {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    return have_pdl ? MarkCudaSyncCalls::Substitute(f)
                    : EliminateCudaSyncCalls::Substitute(f);
  };

  return CreatePrimFuncPass(pass_func, 0, "tl.MarkCudaSyncCalls", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MarkCudaSyncCalls",
                        MarkCudaSyncCallsPass);
}

} // namespace tl
} // namespace tvm
59 changes: 59 additions & 0 deletions testing/python/jit/test_tilelang_jit_ctypes.py
@@ -0,0 +1,59 @@
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch


def test_ctypes_pdl():
"""Test pdl."""

N = 64

@tilelang.jit(execution_backend="ctypes")
def multi_kernels_with_pdl(N, block_size=256, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
for i in T.Parallel(block_size):
idx = bx * block_size + i
if idx < N:
B[idx] = A[idx] + 1.0
T.pdl_trigger()

with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
T.pdl_sync()
for i in T.Parallel(block_size):
idx = bx2 * block_size + i
if idx < N:
C[idx] = B[idx] * 2.0

return main

# Compile the kernel
kernel = multi_kernels_with_pdl(N)

# Create test tensors
a = torch.randn(N, dtype=torch.float32).cuda()
b = torch.randn(N, dtype=torch.float32).cuda()
c = torch.randn(N, dtype=torch.float32).cuda()

ref_b = a + 1.0
ref_c = ref_b * 2.0

kernel(a, b, c)

# Verify correctness

tilelang.testing.torch_assert_close(b, ref_b, atol=1e-5, rtol=1e-5)
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

print("pdl test passed!")


if __name__ == "__main__":
tilelang.testing.main()
50 changes: 50 additions & 0 deletions testing/python/jit/test_tilelang_jit_nvrtc.py
@@ -502,5 +502,55 @@ def kernel(
print("L2 persistent map test passed!")


def test_nvrtc_pdl():
"""Test pdl."""

N = 64

@tilelang.jit(execution_backend="nvrtc")
def multi_kernels_with_pdl(N, block_size=256, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
for i in T.Parallel(block_size):
idx = bx * block_size + i
if idx < N:
B[idx] = A[idx] + 1.0
T.pdl_trigger()

with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
T.pdl_sync()
for i in T.Parallel(block_size):
idx = bx2 * block_size + i
if idx < N:
C[idx] = B[idx] * 2.0

return main

# Compile the kernel
kernel = multi_kernels_with_pdl(N)

# Create test tensors
a = torch.randn(N, dtype=torch.float32).cuda()
b = torch.randn(N, dtype=torch.float32).cuda()
c = torch.randn(N, dtype=torch.float32).cuda()

ref_b = a + 1.0
ref_c = ref_b * 2.0

kernel(a, b, c)

# Verify correctness

tilelang.testing.torch_assert_close(b, ref_b, atol=1e-5, rtol=1e-5)
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

print("pdl test passed!")


if __name__ == "__main__":
tilelang.testing.main()
64 changes: 64 additions & 0 deletions testing/python/language/test_tilelang_language_pdl.py
@@ -0,0 +1,64 @@
import tilelang.testing
import tilelang.language as T


def kernels_with_pdl_trigger(N, block_size=256, dtype=T.float32):
    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
            for i in T.Parallel(block_size):
                idx = bx * block_size + i
                if idx < N:
                    B[idx] = A[idx] + 1.0
            T.pdl_trigger()

    return main


def kernels_with_pdl_sync(N, block_size=256, dtype=T.float32):
    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
            T.pdl_sync()
            for i in T.Parallel(block_size):
                idx = bx2 * block_size + i
                if idx < N:
                    B[idx] = A[idx] * 2.0

    return main


def test_pdl_trigger():
    N = 64
    program = kernels_with_pdl_trigger(N)

    pdl_kernel = tilelang.compile(program, target="cuda -arch=sm_90")
    code = pdl_kernel.get_kernel_source()
    assert "cudaTriggerProgrammaticLaunchCompletion" in code

    old_kernel = tilelang.compile(program, target="cuda -arch=sm_75")
    code = old_kernel.get_kernel_source()
    assert "cudaTriggerProgrammaticLaunchCompletion" not in code


def test_pdl_sync():
    N = 64
    program = kernels_with_pdl_sync(N)

    pdl_kernel = tilelang.compile(program, target="cuda -arch=sm_90")
    code = pdl_kernel.get_kernel_source()
    assert "cudaGridDependencySynchronize" in code

    old_kernel = tilelang.compile(program, target="cuda -arch=sm_75")
    code = old_kernel.get_kernel_source()
    assert "cudaGridDependencySynchronize" not in code


if __name__ == "__main__":
    tilelang.testing.main()
8 changes: 8 additions & 0 deletions tilelang/contrib/nvcc.py
@@ -587,6 +587,14 @@ def is_hopper(target):
    return major == 9 and minor == 0


def have_pdl(target):
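    """Return True if the CUDA target supports Programmatic Dependent Launch (PDL), i.e. compute capability >= 9.0."""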
    if target.kind.name != "cuda":
        return False
    compute_version = get_target_compute_version(target)
    major, minor = parse_compute_version(compute_version)
    return major >= 9


def get_nvcc_compiler() -> str:
"""Get the path to the nvcc compiler"""
return os.path.join(find_cuda_path(), "bin", "nvcc")
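
For illustration, a minimal sketch of the gate `have_pdl` implements, assuming `tvm.target.Target` accepts the same arch strings used as compile targets in the tests above:

```python
from tvm.target import Target

from tilelang.contrib.nvcc import have_pdl

# PDL requires compute capability >= 9.0 (Hopper or newer).
assert have_pdl(Target("cuda -arch=sm_90"))
# On older parts (e.g. sm_75) the PDL sync calls are eliminated instead of emitted.
assert not have_pdl(Target("cuda -arch=sm_75"))
```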
5 changes: 4 additions & 1 deletion tilelang/engine/phase.py
@@ -3,7 +3,7 @@
from tvm.target import Target
import tilelang
from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, is_hopper
from tilelang.contrib.nvcc import have_tma, is_hopper, have_pdl


def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
@@ -252,6 +252,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    mod = tilelang.transform.ThreadSync("global")(mod)
    mod = tilelang.transform.AnnotateDeviceRegions()(mod)
    mod = tilelang.transform.SplitHostDevice()(mod)

    mod = tilelang.transform.MarkCudaSyncCalls(have_pdl(target))(mod)

    mod = tilelang.transform.AnnotateReadOnlyParams()(mod)
    # MergeSharedMemoryAllocations must be applied after SplitHostDevice
    # because the merged allocation site is at the beginning of each device function