[Feat] PDL Support #1494
Open: w169q169 wants to merge 25 commits into tile-ai:main from w169q169:main.
Diff: +512 −4
Commits (the diff below shows changes from 22 of the 25 commits):
2a67387  [Feat] PDL Support (w169q169)
4898b56  remove unused comments/codes (w169q169)
f50aa99  fix mismatch funcname (w169q169)
b846450  Resolve duplicate config issue for AI review (w169q169)
cb052f4  fix typo error && clang-tidy error (w169q169)
8ed4124  chore: retrigger CI (w169q169)
466c94a  change file comment (w169q169)
05b8dc0  rename test_tilelang_jit_ctypes.py to test_tilelang_jit_cython.py (w169q169)
62ed58f  remove ctype backend (w169q169)
46fe83f  Move pdl attributes to header and remove redundant VisitStmt_ (w169q169)
898135e  chore: trigger CI (w169q169)
1c75231  Merge branch 'main' of https://github.com/tile-ai/tilelang (silentCoder-dev)
7ed6748  modify the cuda kernel with pdl due to nvcc's bug (silentCoder-dev)
8978b11  throw an error when invoking __ldg with pdl due to nvcc's bug (silentCoder-dev)
af5f26d  Throw an error when pdl is not supported (silentCoder-dev)
a84034b  fix checking about pdl (silentCoder-dev)
c123fd7  add comments about MarkCudaSyncCalls (silentCoder-dev)
e0f058f  remove pdl support with compute_capability<90 (silentCoder-dev)
017d314  Merge branch 'main' of https://github.com/tile-ai/tilelang (silentCoder-dev)
877962b  ruff check (silentCoder-dev)
8ad39ab  fix nvrtc about pdl (w169q169)
22455f2  Merge branch 'w169q169/main' into HEAD (w169q169)
2f83ea7  update tvm & extend pdl support for tvm_ffi & add test for tvm_ffi (silentCoder-dev)
24117e6  ensure that kUseDynamicSharedMemoryTag is the last tag in launch_para… (silentCoder-dev)
c1cdf7c  Merge remote-tracking branch 'upstream' (silentCoder-dev)
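Background for reviewers: Programmatic Dependent Launch (PDL), available on compute capability 9.0+ (Hopper), lets a dependent kernel begin launching before its predecessor has fully completed. The producer grid calls cudaTriggerProgrammaticLaunchCompletion() to signal that dependents may start, and the consumer grid calls cudaGridDependencySynchronize() to wait for the producer's work. This PR exposes the pair in tilelang as T.pdl_trigger() and T.pdl_sync(). A condensed sketch of the user-facing API, mirroring the tests in this diff:

```python
# Producer/consumer sketch of the PDL API added by this PR; the
# T.pdl_trigger()/T.pdl_sync() intrinsics and the kernel structure
# mirror the tests included in the diff below.
import tilelang
import tilelang.language as T


@tilelang.jit(execution_backend="cython")
def add_then_scale(N, block_size=256, dtype=T.float32):

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
            C: T.Tensor((N,), dtype),
    ):
        # Producer grid: compute B, then signal that dependent kernels may
        # begin launching (lowers to cudaTriggerProgrammaticLaunchCompletion).
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
            for i in T.Parallel(block_size):
                idx = bx * block_size + i
                if idx < N:
                    B[idx] = A[idx] + 1.0
            T.pdl_trigger()

        # Consumer grid: wait for the producer before reading B
        # (lowers to cudaGridDependencySynchronize).
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
            T.pdl_sync()
            for i in T.Parallel(block_size):
                idx = bx2 * block_size + i
                if idx < N:
                    C[idx] = B[idx] * 2.0

    return main
```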
New file: lower_pdl.cc (+99 lines). The pass scans each PrimFunc for call_extern calls to the two CUDA PDL intrinsics and tags matching functions with the kHasTriggerLaunch / kHasGridSync attributes; it rejects __ldg in kernels that use grid sync (NVCC bug, see commit 8978b11) and raises an error if PDL intrinsics appear when support_pdl is false.

```cpp
/*!
 * \file lower_pdl.cc
 * \brief Mark device PrimFuncs with attributes when CUDA PDL functions are called
 */

#include "../op/builtin.h"
#include "../target/utils.h"
#include "common/attr.h"
#include "tvm/ir/type.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tl {

using namespace tir;

// NVCC has issues with __ldg when using PDL (Programmatic Dependent Launch)
// synchronization, so reject __ldg calls when kHasGridSync is set.
class CheckLDGCalls : public StmtExprVisitor {
public:
  void VisitExpr_(const tir::CallNode *op) final {
    if (op->op.same_as(tl::__ldg())) {
      LOG(FATAL) << "Cannot invoke __ldg function with pdl_sync";
    }
    StmtExprVisitor::VisitExpr_(op);
  }
};

class MarkCudaSyncCalls : public StmtExprMutator {
public:
  static PrimFunc Substitute(PrimFunc f, bool support_pdl) {
    MarkCudaSyncCalls mutator;
    PrimFunc new_f = f;
    new_f.CopyOnWrite()->body = mutator.VisitStmt(f->body);

    if (!support_pdl) {
      ICHECK(!mutator.has_trigger_launch_ && !mutator.has_grid_sync_)
          << "PDL is not supported";
    }

    if (mutator.has_trigger_launch_) {
      new_f = WithAttr(std::move(new_f), attr::kHasTriggerLaunch, 1);
    }
    if (mutator.has_grid_sync_) {
      new_f = WithAttr(std::move(new_f), attr::kHasGridSync, 1);
      CheckLDGCalls analyzer;
      analyzer(f->body);
    }
    return new_f;
  }

  PrimExpr VisitExpr_(const tir::CallNode *op) final {
    if (op && op->op.same_as(builtin::call_extern())) {
      if (!op->args.empty()) {
        if (const auto *str_node = op->args[0].as<tvm::tir::StringImmNode>()) {
          std::string func_name = str_node->value;
          if (func_name == "cudaTriggerProgrammaticLaunchCompletion") {
            has_trigger_launch_ = true;
          } else if (func_name == "cudaGridDependencySynchronize") {
            has_grid_sync_ = true;
          }
        }
      }
    }
    return StmtExprMutator::VisitExpr_(op);
  }

private:
  bool has_trigger_launch_ = false;
  bool has_grid_sync_ = false;

  MarkCudaSyncCalls() = default;
};

using namespace tir::transform;

tvm::transform::Pass MarkCudaSyncCallsPass(bool support_pdl) {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    return MarkCudaSyncCalls::Substitute(f, support_pdl);
  };

  return CreatePrimFuncPass(pass_func, 0, "tl.MarkCudaSyncCalls", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MarkCudaSyncCalls",
                        MarkCudaSyncCallsPass);
}

} // namespace tl
} // namespace tvm
```
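For orientation, a minimal sketch of driving the registered pass from Python. The global name "tl.transform.MarkCudaSyncCalls" comes from the FFI registration above; everything else here (the tvm.get_global_func entry point and how tilelang actually wires the pass into its lowering pipeline) is an assumption:

```python
# Minimal sketch (not tilelang's actual pipeline wiring) of invoking the
# MarkCudaSyncCalls pass registered above via TVM's global function table.
from tilelang import tvm as tvm  # assumption: loads the library that registers the pass
from tvm.script import tir as S


@S.prim_func
def trivial(A: S.Buffer((1,), "float32")):
    A[0] = 0.0


mod = tvm.IRModule({"main": trivial})

# The packed function takes support_pdl and returns a tvm.transform.Pass.
mark_cuda_sync_calls = tvm.get_global_func("tl.transform.MarkCudaSyncCalls")(True)

# No PDL intrinsics appear in `trivial`, so nothing is tagged; a PrimFunc whose
# body calls cudaTriggerProgrammaticLaunchCompletion or
# cudaGridDependencySynchronize would gain kHasTriggerLaunch / kHasGridSync.
mod = mark_cuda_sync_calls(mod)
```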
New test file (+71 lines): an end-to-end PDL test on the cython execution backend, skipped when the device's compute capability is below 9.0.

```python
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
import pytest


def check_pdl():
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(0)
    compute_capability = props.major, props.minor
    return compute_capability[0] >= 9


def test_cython_pdl():
    """Test pdl."""

    if not check_pdl():
        pytest.skip("PDL Test requires compute capability >= 9")

    N = 64

    @tilelang.jit(execution_backend="cython")
    def multi_kernels_with_pdl(N, block_size=256, dtype=T.float32):

        @T.prim_func
        def main(
                A: T.Tensor((N,), dtype),
                B: T.Tensor((N,), dtype),
                C: T.Tensor((N,), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
                for i in T.Parallel(block_size):
                    idx = bx * block_size + i
                    if idx < N:
                        B[idx] = A[idx] + 1.0
                T.pdl_trigger()

            with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
                T.pdl_sync()
                for i in T.Parallel(block_size):
                    idx = bx2 * block_size + i
                    if idx < N:
                        C[idx] = B[idx] * 2.0

        return main

    # Compile the kernel
    kernel = multi_kernels_with_pdl(N)

    # Create test tensors
    a = torch.randn(N, dtype=torch.float32).cuda()
    b = torch.randn(N, dtype=torch.float32).cuda()
    c = torch.randn(N, dtype=torch.float32).cuda()

    ref_b = a + 1.0
    ref_c = ref_b * 2.0

    kernel(a, b, c)

    # Verify correctness
    tilelang.testing.torch_assert_close(b, ref_b, atol=1e-5, rtol=1e-5)
    tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

    print("pdl test passed!")


if __name__ == "__main__":
    tilelang.testing.main()
```
New test file (+56 lines): compiles PDL kernels for sm_90 and asserts that the CUDA PDL intrinsics appear in the generated kernel source.

```python
import tilelang.testing
import tilelang.language as T


def kernels_with_pdl_trigger(N, block_size=256, dtype=T.float32):

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx,):
            for i in T.Parallel(block_size):
                idx = bx * block_size + i
                if idx < N:
                    B[idx] = A[idx] + 1.0
            T.pdl_trigger()

    return main


def kernels_with_pdl_sync(N, block_size=256, dtype=T.float32):

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as (bx2,):
            T.pdl_sync()
            for i in T.Parallel(block_size):
                idx = bx2 * block_size + i
                if idx < N:
                    B[idx] = A[idx] * 2.0

    return main


def test_pdl_trigger():
    N = 64
    program = kernels_with_pdl_trigger(N)

    pdl_kernel = tilelang.compile(program, target="cuda -arch=sm_90")
    code = pdl_kernel.get_kernel_source()
    assert "cudaTriggerProgrammaticLaunchCompletion" in code


def test_pdl_sync():
    N = 64
    program = kernels_with_pdl_sync(N)

    pdl_kernel = tilelang.compile(program, target="cuda -arch=sm_90")
    code = pdl_kernel.get_kernel_source()
    assert "cudaGridDependencySynchronize" in code


if __name__ == "__main__":
    tilelang.testing.main()
```