
Conversation

@kurisu6912 (Collaborator) commented on Dec 19, 2025

In the previous version of LazyJIT, some constants and symbolic variables had to be declared outside the function. This cluttered the global namespace and caused name conflicts.

In this revision, we move the type annotations inside the function, which is simpler and clearer:

import tilelang
import tilelang.language as T
import torch

@tilelang.lazy_jit
def gemm(
    A, B, C,
    out_dtype: T.dtype = T.float32,
    block_M: int = 128,
    block_N: int = 128,
    block_K: int = 32,
):
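    # M, N, K are constant shape placeholders, resolved from the shapes of the
    # tensors passed at call time.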
    M, N, K = T.const('M, N, K')

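    # The in-body T.Tensor annotations below declare each argument's shape and
    # dtype; they are lowered to match_buffer on the corresponding arguments.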
    A: T.Tensor[[M, K], T.float16]
    B: T.Tensor[[K, N], T.float16]
    C: T.Tensor[[M, N], out_dtype]

    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N)) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), out_dtype)
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[bx * block_M, by * block_N])

A = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
B = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
C = torch.randn(1024, 1024, dtype=torch.float32, device='cuda')

gemm(A, B, C)
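A quick numerical check against PyTorch can follow the call above (a minimal sketch using standard torch APIs; the tolerances are assumptions for fp16 inputs with fp32 accumulation):

# Reference matmul computed in fp32, compared against the kernel's fp32 output C.
ref = A.float() @ B.float()
torch.testing.assert_close(C, ref, rtol=1e-2, atol=1e-2)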

TODO

  • Two-phase elaboration
  • Add support for T.empty and return values
  • Syntax sugar: remove T.ptr annotation completely
  • Add error report for Python positional-only arguments
  • Write comprehensive tests
  • Add examples

Summary by CodeRabbit

  • New Features

    • Dynamic-shape annotations and in-body shape/constexpr declarations; lazy JIT/template-driven IR generation exposed.
    • Public side-effect helper added.
  • Refactor

    • Simplified public APIs and exports; many type/shape annotations moved from signatures into function bodies; legacy annotation module removed.
    • Streamlined lazy JIT, caching, and logging flows.
  • Documentation

    • Notebooks updated to show direct-call JIT, dynamic-macro usage, and cleaner outputs.
  • Tests

    • Tests consolidated to validate lazy JIT numeric paths; some legacy generator/annotated tests removed.



@LeiWang1999 self-requested a review on December 19, 2025 at 10:07
@kurisu6912 marked this pull request as ready for review on December 23, 2025 at 05:03

@kurisu6912 (Collaborator, Author) commented on Dec 23, 2025

This PR:

  1. Adds a SideEffect function on the Python side, which allows the user to write complex shape expressions in the function body (a sketch of such a check follows this list):

    def foo(A, B):
        N = T.dynamic('N')
        # In the previous version, this generated a LetStmt.
        # In this version, we check the expression for side effects and, if it
        # is pure, skip LetStmt generation.
        M = N * 2 + 1
        A: T.Tensor[[N], T.float32]
        B: T.Tensor[[M], T.float32]
  2. Two-phase elaboration:

    1. Phase 1: generate a function with dynamic/const shape placeholders.
    2. Phase 2: replace the const placeholders with the matched tensor shapes/strides:
      https://github.com/kurisu6912/tilelang/blob/71ed76bbbee5beb96a12f151be83a9b20bfd67f2/tilelang/language/v2/builder.py#L919-L933
  3. Function annotation heuristic: in both JIT and LazyJIT, if a function argument is annotated inside the function body, we apply a transform that converts the annotation into a match_buffer call and adds the buffer to the function arguments (a detection sketch follows this list). For example,

    def foo(A, B):
        A: T.Tensor[[128], T.float32]
        B: T.float32

    is transformed into:

    def foo(A: T.ptr, B: T.float32):
        A = T.match_buffer(A, [128], T.float32)
        B = B
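To make the side-effect check in item 1 concrete, here is a minimal sketch (an illustration under assumptions, not tilelang's actual SideEffect helper) of classifying a shape expression as pure, so that no LetStmt needs to be generated for it:

import ast

# Node types that cannot, by themselves, produce side effects.
PURE_NODES = (ast.Expression, ast.BinOp, ast.UnaryOp, ast.Name, ast.Constant,
              ast.Load, ast.operator, ast.unaryop)

def is_pure_expr(expr_src: str) -> bool:
    # True if the expression is plain arithmetic over names and constants.
    tree = ast.parse(expr_src, mode="eval")
    return all(isinstance(node, PURE_NODES) for node in ast.walk(tree))

print(is_pure_expr("N * 2 + 1"))       # True  -> keep as a plain expression
print(is_pure_expr("foo(N) * 2 + 1"))  # False -> a call could have side effects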
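Similarly, for the annotation heuristic in item 3, in-body parameter annotations can be detected with Python's standard ast module (again an illustrative sketch, not the actual builder implementation):

import ast
import textwrap

src = textwrap.dedent("""
    def foo(A, B):
        A: T.Tensor[[128], T.float32]
        B: T.float32
""")

fn = ast.parse(src).body[0]
params = {a.arg for a in fn.args.args}

for stmt in fn.body:
    # A bare in-body annotation such as `A: T.Tensor[...]` parses as an
    # ast.AnnAssign with value=None; when its target names a parameter, the
    # builder can rewrite it into a match_buffer on that argument.
    if (isinstance(stmt, ast.AnnAssign) and stmt.value is None
            and isinstance(stmt.target, ast.Name) and stmt.target.id in params):
        print(f"{stmt.target.id} annotated in body as {ast.unparse(stmt.annotation)}")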

@tile-ai tile-ai deleted a comment from coderabbitai bot Dec 23, 2025
@kurisu6912 changed the title from "[LazyJIT] Move type annotation inside function" to "[LazyJIT] Move Type Annotations to Function Body" on Dec 23, 2025

@coderabbitai (bot) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tilelang/jit/__init__.py (1)

409-416: Non-lazy path logic is correct, pending fix for __tune_params bug.

The non-lazy path correctly returns the kernel object itself (line 416) rather than executing it, which aligns with the jit decorator's return type JITKernel[_KP, _T]. However, this code is affected by the critical __tune_params double-pop bug identified in lines 376-381.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between bd665cf and 9aee247.

📒 Files selected for processing (2)
  • tilelang/cache/kernel_cache.py
  • tilelang/jit/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/cache/kernel_cache.py
⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (6)
tilelang/jit/__init__.py (6)

275-275: LGTM!

Type annotation correctly includes LazyJITFunc[_KP, _T] to support lazy JIT compilation paths.


291-292: LGTM!

Correctly delegates to LazyJITFunc.get_tir for lazy JIT functions. Consistent with the two-phase elaboration pattern described in PR objectives.


496-514: LGTM!

The jit decorator correctly checks for PrimFunc instances and constructs JITImpl with lazy_jit=False. Removal of PrimFuncCreater handling aligns with PR objectives.


569-573: LGTM!

The lazy_jit decorator correctly creates a LazyJITFunc via prim_func(func, lazy_jit=True) and wraps it in JITImpl with lazy_jit=True, enabling the lazy compilation flow.


27-27: No issues found. LazyJITFunc is properly exported from tilelang.language.v2 and implements the expected methods (get_tir and parse_args).


400-407: No changes needed. The kernel_args returned by LazyJITFunc.parse_args is a dict (tensor_args in the implementation) initialized as {} and populated with tensor arguments. Calling kernel(*kernel_args.values()) is correct.

Likely an incorrect or invalid review comment.

@LeiWang1999 (Member)

@codex review

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


@LeiWang1999 (Member)


import tilelang
import tilelang.language as T
import torch

@tilelang.lazy_jit
def gemm(
    A: T.ptr, B: T.ptr, C: T.ptr,
    out_dtype: T.dtype = T.float32,
    block_M: int = 128,
    block_N: int = 128,
    block_K: int = 32,
):
    M, N, K = T.const('M, N, K')

    A: T.Tensor[[M, K], T.float16]
    B: T.Tensor[[K, N], T.float16]
    C: T.Tensor[[M, N], out_dtype]

    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N)) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), out_dtype)
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[bx * block_M, by * block_N])

A = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
B = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
C = torch.randn(1024, 1024, dtype=torch.float32, device='cuda')

gemm(A, B, C)


In the example quoted above, A, B, and C are annotated with T.ptr; I think that's unnecessary and may lead to confusion, so maybe we should remove it.

@kurisu6912 (Collaborator, Author)

Updated; the user no longer needs to annotate A, B, and C.

@LeiWang1999 (Member)

Overall LGTM, but regarding the function annotation heuristic ("In both JIT and LazyJIT, if the function arguments are annotated inside the function body, we apply transforms to change it to match_buffer and add it to arguments") I have one concern: can the user still use T.ptr as a function annotation if they just want to pass a raw pointer as input?

@tilelang.lazy_jit
def func(
   A:T.ptr
):
   ...

func(tensor.data_ptr())

@LeiWang1999 previously approved these changes on Dec 27, 2025
@LeiWang1999 (Member)


We already have a ptr-based test, and it passes.

@LeiWang1999 (Member)

@codex review


