Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker to confirm this hasn't already been reported. (Comment there if it has.)
Questions
Issue Description
After performing auto-tuning for the prefill kernel, I applied the selected Tile configurations (BLOCK_M, BLOCK_N, NUM_STAGES, NUM_THREADS) directly to the decode operator. However, the decode kernel triggers a re-compilation at every single step, consuming several seconds per iteration. This significantly bottlenecks the decode throughput. (Note: I haven't encountered this behavior with Triton kernels, so I am seeking guidance on the proper fix within TileLang.)
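For reference, the reused configuration looks roughly like the sketch below; the key names are the ones mentioned above, while the values are illustrative placeholders rather than the tuner's actual selections:

```python
# Illustrative placeholder values only; the real ones come from
# autotuning the prefill kernel.
kernel_config = {
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "NUM_STAGES": 2,
    "NUM_THREADS": 128,
}
```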
Environment
- Test Script: examples/test_sdar_diffulex_gsm8k.py
- Model: SDAR-1.7B-Chat-b32
- Decoding Strategy: block_diffusion
- Log File: log/test_sdar_diffulex_gsm8k.log
Steps to Reproduce
1. Run the test script

   Set up the Diffulex project:

   ```bash
   uv sync
   ```

   Download the model JetLM/SDAR-1.7B-Chat-b32, update the model path in the script, and run:

   ```bash
   python examples/test_sdar_diffulex_gsm8k.py 2>&1 | tee log/test_sdar_diffulex_gsm8k.log
   ```

2. Observe the logs

   During the generation phase (after the first 10 outputs), the decode kernel recompiles at every step.
Log Analysis
The logs in log/test_sdar_diffulex_gsm8k.log show the following:
```
Generating: 100%|██████████| 10/10 [02:25<00:00, 14.55s/it, Prefill=13768tok/s, Decode=42tok/s]
2025-12-25 08:28:18 [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel` with `out_idx=[10]`
2025-12-25 08:28:22 [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `kernel`
2025-12-25 08:28:25 [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel` with `out_idx=[10]`
2025-12-25 08:28:29 [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `kernel`
...
```

The recompilation repeats at every step, taking ~4-5 seconds each time.
**Performance Metrics:**
- **Prefill Throughput:** 13768.71 tok/s
- **Decode Throughput:** 42.53 tok/s (Severely degraded)
- **Average TPS:** 8.11 tok/s
- **Total Execution Time:** 145.53s
Code Analysis
In diffulex_kernel/python/dllm_flash_attn.py:
- Prefill Kernel: Uses the `@tilelang.autotune` decorator (line 24) to perform tuning during the warmup phase.
  https://github.com/zhijie-group/Diffulex/blob/main/diffulex_kernel/python/dllm_flash_attn.py#L167

- Decode Kernel: Uses only the `@tilelang.jit` decorator (line 164), without autotuning:
  https://github.com/zhijie-group/Diffulex/blob/main/diffulex_kernel/python/dllm_flash_attn.py#L28

  ```python
  @tilelang.jit(
      out_idx=[-1],
      pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True},
  )
  def dllm_flash_attn_decode_kernel(...):
  ```

- Trigger Point: Inside the `dllm_flash_attn_decode` function (lines 609-632), a new kernel instance is created on every call:
  https://github.com/zhijie-group/Diffulex/blob/main/diffulex/attention/metadata.py#L20

  ```python
  def dllm_flash_attn_decode(...):
      if attn_metadata.decode_mode == "static":
          # New kernel instance created per call, triggering JIT
          decode_kernel = dllm_flash_attn_decode_kernel(
              attn_metadata.num_seqs,
              ...,
              **kernel_config,
          )
          return decode_kernel(...)
  ```
Root Cause Analysis
- The `dllm_flash_attn_decode` function instantiates a new kernel object on every invocation.
- Even though `kernel_config` is passed, `out_idx=[-1]` requires dynamic shape inference based on runtime tensors (logged as `out_idx=[10]`).
- The TileLang JIT compiler appears to recompile for every new kernel instance and parameter combination, even if the configurations are identical.
- The lack of a kernel-instance caching mechanism leads to massive compilation overhead during the decode phase.
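As a possible fix, here is a minimal sketch of the caching I would expect, assuming all compile-time arguments are hashable. The wrapper name `get_decode_kernel` is hypothetical, not an existing Diffulex or TileLang API; `functools.lru_cache` is just one way to memoize the instantiation:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def get_decode_kernel(num_seqs, **kernel_config):
    # The @tilelang.jit-decorated factory only runs on a cache miss,
    # so JIT compilation happens once per unique parameter combination.
    return dllm_flash_attn_decode_kernel(num_seqs, **kernel_config)

# Per decode step: fetch the cached kernel instead of re-instantiating.
decode_kernel = get_decode_kernel(attn_metadata.num_seqs, **kernel_config)
```

One caveat: if `attn_metadata.num_seqs` (or any other shape-like argument) changes between steps, each distinct value would still compile its own specialization, so such arguments may need to stay dynamic or be bucketed.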
Expected Behavior
- The decode kernel should compile only once, on the initial call.
- Subsequent calls with the same parameter combination should reuse the cached, pre-compiled kernel.
- This would eliminate the multi-second compilation latency at each decoding step.
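A quick way to verify this behavior once caching is in place (a hypothetical check built on the `get_decode_kernel` sketch above): only the first step should pay the compilation cost.

```python
import time

for step in range(3):
    t0 = time.perf_counter()
    kernel = get_decode_kernel(attn_metadata.num_seqs, **kernel_config)
    # Expected: step 0 takes seconds (JIT compile); later steps are ~instant.
    print(f"step {step}: kernel fetch took {time.perf_counter() - t0:.3f}s")
```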