Commits (changes shown from 14 of 25 commits)
2a67387
[Feat] PDL Support
w169q169 Dec 21, 2025
4898b56
remove unused comments/codes
w169q169 Dec 22, 2025
f50aa99
fix mismatch funcname
w169q169 Dec 22, 2025
b846450
Resolve duplicate config issue for AI review
w169q169 Dec 22, 2025
cb052f4
fix typo error && clang-tidy error
w169q169 Dec 22, 2025
8ed4124
chore: retrigger CI
w169q169 Dec 23, 2025
466c94a
change file comment
w169q169 Dec 23, 2025
05b8dc0
rename test_tilelang_jit_ctypes.py to test_tilelang_jit_cython.py
w169q169 Dec 23, 2025
62ed58f
remove ctype backend
w169q169 Dec 24, 2025
46fe83f
Move pdl attributes to header and remove redundant VisitStmt_
w169q169 Dec 24, 2025
898135e
chore: trigger CI
w169q169 Dec 24, 2025
1c75231
Merge branch 'main' of https://github.com/tile-ai/tilelang
silentCoder-dev Dec 25, 2025
7ed6748
modify the cuda kernel with pdl due to nvcc's bug
silentCoder-dev Dec 25, 2025
8978b11
throw an error when invoking __ldg with pdl due to nvcc's bug
silentCoder-dev Dec 25, 2025
af5f26d
Throw an error when pdl is not supported
silentCoder-dev Dec 25, 2025
a84034b
fix checking about pdl
silentCoder-dev Dec 25, 2025
c123fd7
add comments about MarkCudaSyncCalls
silentCoder-dev Dec 25, 2025
e0f058f
remove pdl support with compute_capability<90
silentCoder-dev Dec 26, 2025
017d314
Merge branch 'main' of https://github.com/tile-ai/tilelang
silentCoder-dev Dec 26, 2025
877962b
ruff check
silentCoder-dev Dec 26, 2025
8ad39ab
fix nvrtc about pdl
w169q169 Dec 26, 2025
22455f2
Merge branch 'w169q169/main' into HEAD
w169q169 Dec 26, 2025
2f83ea7
update tvm & extend pdl support for tvm_ffi & add test for tvm_ffi
silentCoder-dev Dec 26, 2025
24117e6
ensure that kUseDynamicSharedMemoryTag is the last tag in launch_para…
silentCoder-dev Dec 29, 2025
c1cdf7c
Merge remote-tracking branch 'upstream'
silentCoder-dev Dec 29, 2025
4 changes: 4 additions & 0 deletions docs/programming_guides/instructions.md
@@ -138,6 +138,10 @@ Annotation helpers
- `T.annotate_safe_value(var, ...)`: Safety/const hints.
- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint.

Synchronization helpers
- `T.pdl_trigger()`: Signal programmatic launch completion for the current kernel.
- `T.pdl_sync()`: Wait until kernel dependencies are satisfied.

Atomics
- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`.
- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`.
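For context, these two helpers correspond to the device-side CUDA PDL calls that the new lower_pdl.cc pass (later in this diff) matches by name. Below is a minimal hand-written CUDA sketch of the producer/consumer pattern, for illustration only: the kernel names and shapes are made up, this is not TileLang-generated code, and PDL requires sm_90+ per this PR's capability check.

#include <cuda_runtime.h>

// Producer: does its work, then signals that the dependent launch may begin.
__global__ void producer(const float *a, float *b, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) b[idx] = a[idx] + 1.0f;
  // Roughly what T.pdl_trigger() lowers to via call_extern.
  cudaTriggerProgrammaticLaunchCompletion();
}

// Consumer: waits for the producer's signal before reading its output.
__global__ void consumer(const float *b, float *c, int n) {
  // Roughly what T.pdl_sync() lowers to via call_extern.
  cudaGridDependencySynchronize();
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) c[idx] = b[idx] * 2.0f;
}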
11 changes: 9 additions & 2 deletions src/target/codegen_cuda.cc
@@ -14,6 +14,7 @@
#include <vector>

#include "../op/builtin.h"
#include "../transform/common/attr.h"
#include "./ptx.h"
#include "./utils.h"
#include "arith/pattern_match.h"
@@ -3332,6 +3333,9 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
CodeGenC::PrintType(func->ret_type, os);
CodeGenC::PrintExtraAttrs(func, os);
bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
// NVCC has issues with __restrict__ on kernel parameters when using PDL (Programmatic Dependent Launch)
// synchronization. Suppress the annotation when kHasGridSync is set.
bool has_cuda_pdl_sync = func->HasNonzeroAttr(tl::attr::kHasGridSync);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
func->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
@@ -3381,7 +3385,7 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
}
}

if (no_alias && !non_restrict.count(v.get())) {
if (!has_cuda_pdl_sync && no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, os);
}
} else {
@@ -3417,6 +3421,9 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
// NVCC has issues with __restrict__ on kernel parameters when using PDL (Programmatic Dependent Launch)
// synchronization. Suppress the annotation when kHasGridSync is set.
bool has_cuda_pdl_sync = f->HasNonzeroAttr(tl::attr::kHasGridSync);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
@@ -3468,7 +3475,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
}
}

if (no_alias && !non_restrict.count(v.get())) {
if (!has_cuda_pdl_sync && no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
6 changes: 6 additions & 0 deletions src/transform/common/attr.h
@@ -11,5 +11,11 @@ constexpr const char *MainBlockName = "tilelang_root";
constexpr const char *tilelang_is_cpu_kernel_frame =
"tilelang.is_cpu_kernel_frame";

namespace attr {
// Attributes to mark CUDA sync calls
constexpr const char *kHasTriggerLaunch = "has_cuda_pdl_trigger";
constexpr const char *kHasGridSync = "has_cuda_pdl_sync";
} // namespace attr

} // namespace tl
} // namespace tvm
147 changes: 147 additions & 0 deletions src/transform/lower_pdl.cc
@@ -0,0 +1,147 @@
/*!
* \file lower_pdl.cc
 * \brief Mark device PrimFuncs with attributes when CUDA PDL functions are called
*/

#include "../op/builtin.h"
#include "../target/utils.h"
#include "common/attr.h"
#include "tvm/ir/type.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tl {

using namespace tir;

// NVCC has issues with __ldg when using PDL (Programmatic Dependent Launch)
// synchronization. Raise an error when __ldg is invoked in a kernel that
// uses pdl_sync (kHasGridSync).
class CheckLDGCalls : public StmtExprVisitor {
public:
void VisitExpr_(const tir::CallNode *op) final {
if (op->op.same_as(tl::__ldg())) {
LOG(FATAL) << "Cannot invoke __ldg function with pdl_sync";
}
StmtExprVisitor::VisitExpr_(op);
}
};

class MarkCudaSyncCalls : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
MarkCudaSyncCalls mutator;
PrimFunc new_f = f;
new_f.CopyOnWrite()->body = mutator.VisitStmt(f->body);

if (mutator.has_trigger_launch_) {
new_f = WithAttr(std::move(new_f), attr::kHasTriggerLaunch, 1);
}
if (mutator.has_grid_sync_) {
new_f = WithAttr(std::move(new_f), attr::kHasGridSync, 1);
CheckLDGCalls analyzer;
analyzer(f->body);
}
return new_f;
}

PrimExpr VisitExpr_(const tir::CallNode *op) final {
CheckCall(op);
return StmtExprMutator::VisitExpr_(op);
}

private:
void CheckCall(const tir::CallNode *call) {
if (!call)
return;
if (call->op.same_as(builtin::call_extern())) {
if (!call->args.empty()) {
if (const auto *str_node =
call->args[0].as<tvm::tir::StringImmNode>()) {
std::string func_name = str_node->value;
if (func_name == "cudaTriggerProgrammaticLaunchCompletion") {
has_trigger_launch_ = true;
} else if (func_name == "cudaGridDependencySynchronize") {
has_grid_sync_ = true;
}
}
}
}
}

private:
bool has_trigger_launch_ = false;
bool has_grid_sync_ = false;

MarkCudaSyncCalls() = default;
};

class EliminateCudaSyncCalls : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
EliminateCudaSyncCalls mutator;
PrimFunc new_f = f;
new_f.CopyOnWrite()->body = mutator.VisitStmt(f->body);

return new_f;
}

PrimExpr VisitExpr_(const tir::CallNode *op) final {
if (CheckCall(op)) {
return make_zero(op->dtype);
}

return StmtExprMutator::VisitExpr_(op);
}

private:
bool CheckCall(const tir::CallNode *call) {
if (!call)
return false;

if (call->op.same_as(builtin::call_extern())) {
if (!call->args.empty()) {
if (const auto *str_node =
call->args[0].as<tvm::tir::StringImmNode>()) {
std::string func_name = str_node->value;
if (func_name == "cudaTriggerProgrammaticLaunchCompletion") {
return true;
} else if (func_name == "cudaGridDependencySynchronize") {
return true;
}
}
}
}

return false;
}

private:
EliminateCudaSyncCalls() = default;
};

using namespace tir::transform;

tvm::transform::Pass MarkCudaSyncCallsPass(bool have_pdl) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return have_pdl ? MarkCudaSyncCalls::Substitute(f)
: EliminateCudaSyncCalls::Substitute(f);
};

return CreatePrimFuncPass(pass_func, 0, "tl.MarkCudaSyncCalls", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MarkCudaSyncCalls",
MarkCudaSyncCallsPass);
}

} // namespace tl
} // namespace tvm
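Note that this pass only marks (or, when PDL is disabled, strips) the device-side calls; actual overlap also depends on how the dependent kernel is launched. A hedged host-side sketch using the plain CUDA runtime API follows; the function and variable names are hypothetical, and TileLang's execution backends may configure this differently.

#include <cuda_runtime.h>

// Hypothetical helper: launch a dependent kernel with PDL enabled so it may
// start before the preceding kernel on the same stream has fully finished.
void launch_consumer_with_pdl(void (*consumer)(const float *, float *, int),
                              const float *d_b, float *d_c, int n,
                              dim3 grid, dim3 block, cudaStream_t stream) {
  cudaLaunchConfig_t cfg = {};
  cfg.gridDim = grid;
  cfg.blockDim = block;
  cfg.stream = stream;

  // Allow overlap with the producer once it has called
  // cudaTriggerProgrammaticLaunchCompletion().
  cudaLaunchAttribute attr;
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = 1;
  cfg.attrs = &attr;
  cfg.numAttrs = 1;

  // The consumer kernel is expected to call cudaGridDependencySynchronize().
  cudaLaunchKernelEx(&cfg, consumer, d_b, d_c, n);
}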
12 changes: 12 additions & 0 deletions src/transform/warp_specialized_rewriter.cc
@@ -164,6 +164,18 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
if (call->op.same_as(loop_break())) {
role = Role::kBoth;
}
if (call->op.same_as(builtin::call_extern())) {
if (!call->args.empty()) {
if (const auto *str_node =
call->args[0].as<tvm::tir::StringImmNode>()) {
std::string func_name = str_node->value;
if (func_name == "cudaGridDependencySynchronize" ||
func_name == "cudaTriggerProgrammaticLaunchCompletion") {
role = Role::kBoth;
}
}
}
}
}
SetRole(op, role);
}
59 changes: 59 additions & 0 deletions testing/python/jit/test_tilelang_jit_cython.py
@@ -0,0 +1,59 @@
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch


def test_cython_pdl():
"""Test pdl."""

N = 64

@tilelang.jit(execution_backend="cython")
def multi_kernels_with_pdl(N, block_size=256, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
for i in T.Parallel(block_size):
idx = bx * block_size + i
if idx < N:
B[idx] = A[idx] + 1.0
T.pdl_trigger()

with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
T.pdl_sync()
for i in T.Parallel(block_size):
idx = bx2 * block_size + i
if idx < N:
C[idx] = B[idx] * 2.0

return main

# Compile the kernel
kernel = multi_kernels_with_pdl(N)

# Create test tensors
a = torch.randn(N, dtype=torch.float32).cuda()
b = torch.randn(N, dtype=torch.float32).cuda()
c = torch.randn(N, dtype=torch.float32).cuda()

ref_b = a + 1.0
ref_c = ref_b * 2.0

kernel(a, b, c)

# Verify correctness

tilelang.testing.torch_assert_close(b, ref_b, atol=1e-5, rtol=1e-5)
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

print("pdl test passed!")


if __name__ == "__main__":
tilelang.testing.main()
50 changes: 50 additions & 0 deletions testing/python/jit/test_tilelang_jit_nvrtc.py
@@ -494,5 +494,55 @@ def kernel(
print("L2 persistent map test passed!")


def test_nvrtc_pdl():
"""Test pdl."""

N = 64

@tilelang.jit(execution_backend="nvrtc")
def multi_kernels_with_pdl(N, block_size=256, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
for i in T.Parallel(block_size):
idx = bx * block_size + i
if idx < N:
B[idx] = A[idx] + 1.0
T.pdl_trigger()

with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
T.pdl_sync()
for i in T.Parallel(block_size):
idx = bx2 * block_size + i
if idx < N:
C[idx] = B[idx] * 2.0

return main

# Compile the kernel
kernel = multi_kernels_with_pdl(N)

# Create test tensors
a = torch.randn(N, dtype=torch.float32).cuda()
b = torch.randn(N, dtype=torch.float32).cuda()
c = torch.randn(N, dtype=torch.float32).cuda()

ref_b = a + 1.0
ref_c = ref_b * 2.0

kernel(a, b, c)

# Verify correctness

tilelang.testing.torch_assert_close(b, ref_b, atol=1e-5, rtol=1e-5)
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

print("pdl test passed!")


if __name__ == "__main__":
tilelang.testing.main()