Changes from 15 commits (25 commits in total):
2a67387
[Feat] PDL Support
w169q169 Dec 21, 2025
4898b56
remove unused comments/codes
w169q169 Dec 22, 2025
f50aa99
fix mismatch funcname
w169q169 Dec 22, 2025
b846450
Resolve duplicate config issue for AI review
w169q169 Dec 22, 2025
cb052f4
fix typo error && clang-tidy error
w169q169 Dec 22, 2025
8ed4124
chore: retrigger CI
w169q169 Dec 23, 2025
466c94a
change file comment
w169q169 Dec 23, 2025
05b8dc0
rename test_tilelang_jit_ctypes.py to test_tilelang_jit_cython.py
w169q169 Dec 23, 2025
62ed58f
remove ctype backend
w169q169 Dec 24, 2025
46fe83f
Move pdl attributes to header and remove redundant VisitStmt_
w169q169 Dec 24, 2025
898135e
chore: trigger CI
w169q169 Dec 24, 2025
1c75231
Merge branch 'main' of https://github.com/tile-ai/tilelang
silentCoder-dev Dec 25, 2025
7ed6748
modify the cuda kernel with pdl due to nvcc's bug
silentCoder-dev Dec 25, 2025
8978b11
throw an error when invoking __ldg with pdl due to nvcc's bug
silentCoder-dev Dec 25, 2025
af5f26d
Throw an error when pdl is not supported
silentCoder-dev Dec 25, 2025
a84034b
fix checking about pdl
silentCoder-dev Dec 25, 2025
c123fd7
add comments about MarkCudaSyncCalls
silentCoder-dev Dec 25, 2025
e0f058f
remove pdl support with compute_capability<90
silentCoder-dev Dec 26, 2025
017d314
Merge branch 'main' of https://github.com/tile-ai/tilelang
silentCoder-dev Dec 26, 2025
877962b
ruff check
silentCoder-dev Dec 26, 2025
8ad39ab
fix nvrtc about pdl
w169q169 Dec 26, 2025
22455f2
Merge branch 'w169q169/main' into HEAD
w169q169 Dec 26, 2025
2f83ea7
update tvm & extend pdl support for tvm_ffi & add test for tvm_ffi
silentCoder-dev Dec 26, 2025
24117e6
ensure that kUseDynamicSharedMemoryTag is the last tag in launch_para…
silentCoder-dev Dec 29, 2025
c1cdf7c
Merge remote-tracking branch 'upstream'
silentCoder-dev Dec 29, 2025
4 changes: 4 additions & 0 deletions docs/programming_guides/instructions.md
@@ -138,6 +138,10 @@ Annotation helpers
- `T.annotate_safe_value(var, ...)`: Safety/const hints.
- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint.

Synchronization helpers
- `T.pdl_trigger()`: Signal programmatic launch completion for the current kernel.
- `T.pdl_sync()`: Wait until kernel dependencies are satisfied.

Atomics
- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`.
- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`.
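For reference, `T.pdl_trigger()` and `T.pdl_sync()` map to the CUDA Programmatic Dependent Launch device calls that the lowering pass in this PR detects by name (`cudaTriggerProgrammaticLaunchCompletion` and `cudaGridDependencySynchronize`, available on sm_90+). A hand-written device-side sketch of a dependent kernel pair, for illustration only (not codegen output from this PR):

```cuda
#include <cuda_runtime.h>

// Primary kernel: does its work, then signals programmatic launch completion
// so a dependent kernel may begin early (the equivalent of T.pdl_trigger()).
__global__ void producer(const float* a, float* b, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) b[idx] = a[idx] + 1.0f;
  cudaTriggerProgrammaticLaunchCompletion();
}

// Secondary kernel: waits until the primary grid it depends on has completed
// before reading its output (the equivalent of T.pdl_sync()).
__global__ void consumer(const float* b, float* c, int n) {
  cudaGridDependencySynchronize();
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) c[idx] = b[idx] * 2.0f;
}
```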
13 changes: 11 additions & 2 deletions src/target/codegen_cuda.cc
@@ -14,6 +14,7 @@
#include <vector>

#include "../op/builtin.h"
#include "../transform/common/attr.h"
#include "./ptx.h"
#include "./utils.h"
#include "arith/pattern_match.h"
@@ -3332,6 +3333,10 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
CodeGenC::PrintType(func->ret_type, os);
CodeGenC::PrintExtraAttrs(func, os);
bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
// NVCC has issues with __restrict__ on kernel parameters when using PDL
// (Programmatic Dependent Launch) synchronization. Suppress the annotation
// when kHasGridSync is set.
bool has_cuda_pdl_sync = func->HasNonzeroAttr(tl::attr::kHasGridSync);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
func->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
@@ -3381,7 +3386,7 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
}
}

if (no_alias && !non_restrict.count(v.get())) {
if (!has_cuda_pdl_sync && no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, os);
}
} else {
@@ -3417,6 +3422,10 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
// NVCC has issues with __restrict__ on kernel parameters when using PDL
// (Programmatic Dependent Launch) synchronization. Suppress the annotation
// when kHasGridSync is set.
bool has_cuda_pdl_sync = f->HasNonzeroAttr(tl::attr::kHasGridSync);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
@@ -3468,7 +3477,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
}
}

if (no_alias && !non_restrict.count(v.get())) {
if (!has_cuda_pdl_sync && no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
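The net effect of the two changes above on the emitted signature: when `kHasGridSync` is set, `__restrict__` is simply omitted even though `kNoAlias` would normally request it. A hedged before/after sketch (hand-written; kernel and parameter names are illustrative, not actual codegen output):

```cuda
// Without PDL: kNoAlias is honored and pointer parameters carry __restrict__.
extern "C" __global__ void main_kernel(float* __restrict__ A,
                                       float* __restrict__ B,
                                       float* __restrict__ C);

// With kHasGridSync (the kernel uses PDL synchronization): __restrict__ is
// dropped to work around the NVCC issue described in the comments above.
extern "C" __global__ void main_kernel_pdl(float* A, float* B, float* C);
```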
6 changes: 6 additions & 0 deletions src/transform/common/attr.h
@@ -11,5 +11,11 @@ constexpr const char *MainBlockName = "tilelang_root";
constexpr const char *tilelang_is_cpu_kernel_frame =
"tilelang.is_cpu_kernel_frame";

namespace attr {
// Attributes to mark CUDA sync calls
constexpr const char *kHasTriggerLaunch = "has_cuda_pdl_trigger";
constexpr const char *kHasGridSync = "has_cuda_pdl_sync";
} // namespace attr

} // namespace tl
} // namespace tvm
99 changes: 99 additions & 0 deletions src/transform/lower_pdl.cc
@@ -0,0 +1,99 @@
/*!
* \file lower_pdl.cc
* \brief Mark device PrimFuncs with attributes when CUDA PDL functions are called
*/

#include "../op/builtin.h"
#include "../target/utils.h"
#include "common/attr.h"
#include "tvm/ir/type.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tl {

using namespace tir;

// NVCC has issues with __ldg when using PDL (Programmatic Dependent Launch)
// synchronization, so reject __ldg calls when kHasGridSync is set.
class CheckLDGCalls : public StmtExprVisitor {
public:
void VisitExpr_(const tir::CallNode *op) final {
if (op->op.same_as(tl::__ldg())) {
LOG(FATAL) << "Cannot invoke __ldg function with pdl_sync";
}
StmtExprVisitor::VisitExpr_(op);
}
};

class MarkCudaSyncCalls : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f, bool support_pdl) {
MarkCudaSyncCalls mutator;
PrimFunc new_f = f;
new_f.CopyOnWrite()->body = mutator.VisitStmt(f->body);

if (!support_pdl) {
ICHECK(!mutator.has_trigger_launch_ && !mutator.has_grid_sync_)
<< "PDL is not supported on this target";
}

if (mutator.has_trigger_launch_) {
new_f = WithAttr(std::move(new_f), attr::kHasTriggerLaunch, 1);
}
if (mutator.has_grid_sync_) {
new_f = WithAttr(std::move(new_f), attr::kHasGridSync, 1);
CheckLDGCalls analyzer;
analyzer(f->body);
}
return new_f;
}

PrimExpr VisitExpr_(const tir::CallNode *op) final {
if (op && op->op.same_as(builtin::call_extern())) {
if (!op->args.empty()) {
if (const auto *str_node = op->args[0].as<tvm::tir::StringImmNode>()) {
std::string func_name = str_node->value;
if (func_name == "cudaTriggerProgrammaticLaunchCompletion") {
has_trigger_launch_ = true;
} else if (func_name == "cudaGridDependencySynchronize") {
has_grid_sync_ = true;
}
}
}
}
return StmtExprMutator::VisitExpr_(op);
}

private:
bool has_trigger_launch_ = false;
bool has_grid_sync_ = false;

MarkCudaSyncCalls() = default;
};

using namespace tir::transform;

tvm::transform::Pass MarkCudaSyncCallsPass(bool support_pdl) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return MarkCudaSyncCalls::Substitute(f, support_pdl);
};

return CreatePrimFuncPass(pass_func, 0, "tl.MarkCudaSyncCalls", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MarkCudaSyncCalls",
MarkCudaSyncCallsPass);
}

} // namespace tl
} // namespace tvm
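To make the `CheckLDGCalls` restriction concrete, the kind of kernel it guards against is one that both waits on PDL dependencies and reads through `__ldg`. Written here as CUDA source for illustration only; the actual check runs on the TIR body before codegen, and the combination never reaches NVCC because the pass raises `LOG(FATAL)` first:

```cuda
__global__ void rejected_pattern(const float* b, float* c, int n) {
  cudaGridDependencySynchronize();              // lowered from T.pdl_sync()
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) c[idx] = __ldg(&b[idx]) * 2.0f;  // __ldg + pdl_sync: rejected
}
```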
12 changes: 12 additions & 0 deletions src/transform/warp_specialized_rewriter.cc
@@ -164,6 +164,18 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
if (call->op.same_as(loop_break())) {
role = Role::kBoth;
}
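// PDL intrinsics (grid dependency sync / programmatic launch completion)
// are kept in both the producer and consumer partitions, like loop_break.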
if (call->op.same_as(builtin::call_extern())) {
if (!call->args.empty()) {
if (const auto *str_node =
call->args[0].as<tvm::tir::StringImmNode>()) {
std::string func_name = str_node->value;
if (func_name == "cudaGridDependencySynchronize" ||
func_name == "cudaTriggerProgrammaticLaunchCompletion") {
role = Role::kBoth;
}
}
}
}
}
SetRole(op, role);
}
71 changes: 71 additions & 0 deletions testing/python/jit/test_tilelang_jit_cython.py
@@ -0,0 +1,71 @@
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
import pytest


def check_pdl():
if not torch.cuda.is_available():
return False
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability[0] >= 9


def test_cython_pdl():
"""Test pdl."""

if not check_pdl():
pytest.skip("PDL Test requires compute capability >= 9")

N = 64

@tilelang.jit(execution_backend="cython")
def multi_kernels_with_pdl(N, block_size=256, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
for i in T.Parallel(block_size):
idx = bx * block_size + i
if idx < N:
B[idx] = A[idx] + 1.0
T.pdl_trigger()

with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
T.pdl_sync()
for i in T.Parallel(block_size):
idx = bx2 * block_size + i
if idx < N:
C[idx] = B[idx] * 2.0

return main

# Compile the kernel
kernel = multi_kernels_with_pdl(N)

# Create test tensors
a = torch.randn(N, dtype=torch.float32).cuda()
b = torch.randn(N, dtype=torch.float32).cuda()
c = torch.randn(N, dtype=torch.float32).cuda()

ref_b = a + 1.0
ref_c = ref_b * 2.0

kernel(a, b, c)

# Verify correctness

tilelang.testing.torch_assert_close(b, ref_b, atol=1e-5, rtol=1e-5)
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

print("pdl test passed!")


if __name__ == "__main__":
tilelang.testing.main()
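For context on the runtime side: CUDA's Programmatic Dependent Launch only overlaps kernels when the dependent kernel is launched with the programmatic stream serialization launch attribute. The sketch below shows the standard CUDA runtime API pattern for that; it is an illustration of the underlying mechanism (presumably what the execution backend arranges when `has_cuda_pdl_sync` is set), not code from this PR:

```cuda
#include <cuda_runtime.h>

// Dependent kernel: blocks at cudaGridDependencySynchronize() until the
// primary kernel running ahead of it in the same stream has completed.
extern "C" __global__ void secondary_kernel(const float* b, float* c, int n) {
  cudaGridDependencySynchronize();
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) c[idx] = b[idx] * 2.0f;
}

void launch_dependent(const float* b, float* c, int n, cudaStream_t stream) {
  // Opt this launch into programmatic stream serialization (PDL).
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;

  cudaLaunchConfig_t config = {};
  config.gridDim = dim3((n + 255) / 256);
  config.blockDim = dim3(256);
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.attrs = attrs;
  config.numAttrs = 1;

  // The kernel may begin its prologue while the preceding kernel on `stream`
  // is still running; correctness is preserved by the device-side sync above.
  cudaLaunchKernelEx(&config, secondary_kernel, b, c, n);
}
```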
62 changes: 62 additions & 0 deletions testing/python/jit/test_tilelang_jit_nvrtc.py
@@ -4,6 +4,7 @@
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type
import pytest


def matmul(
@@ -494,5 +495,66 @@ def kernel(
print("L2 persistent map test passed!")


def check_pdl():
if not torch.cuda.is_available():
return False
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability[0] >= 9


def test_nvrtc_pdl():
"""Test pdl."""

if not check_pdl():
pytest.skip("PDL Test requires compute capability >= 9")

N = 64

@tilelang.jit(execution_backend="nvrtc")
def multi_kernels_with_pdl(N, block_size=256, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
for i in T.Parallel(block_size):
idx = bx * block_size + i
if idx < N:
B[idx] = A[idx] + 1.0
T.pdl_trigger()

with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
T.pdl_sync()
for i in T.Parallel(block_size):
idx = bx2 * block_size + i
if idx < N:
C[idx] = B[idx] * 2.0

return main

# Compile the kernel
kernel = multi_kernels_with_pdl(N)

# Create test tensors
a = torch.randn(N, dtype=torch.float32).cuda()
b = torch.randn(N, dtype=torch.float32).cuda()
c = torch.randn(N, dtype=torch.float32).cuda()

ref_b = a + 1.0
ref_c = ref_b * 2.0

kernel(a, b, c)

# Verify correctness

tilelang.testing.torch_assert_close(b, ref_b, atol=1e-5, rtol=1e-5)
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

print("pdl test passed!")


if __name__ == "__main__":
tilelang.testing.main()