[Question] Performance Regression: Decode Kernel Recompiles at Every Step #1535

@drewjin


Issue Description

After auto-tuning the prefill kernel, I applied the selected tile configurations (BLOCK_M, BLOCK_N, NUM_STAGES, NUM_THREADS) directly to the decode operator. However, the decode kernel triggers a recompilation at every single step, costing several seconds per iteration and severely bottlenecking decode throughput. (Note: I have not seen this behavior with Triton kernels, so I am seeking guidance on the proper fix within TileLang.)

Steps to Reproduce

1. Run the test script

Set up the project Diffulex:

uv sync

Download the model JetLM/SDAR-1.7B-Chat-b32, update the model path in the script, and run:

python examples/test_sdar_diffulex_gsm8k.py 2>&1 | tee log/test_sdar_diffulex_gsm8k.log

2. Observe the logs

During the generation phase (after the first 10 outputs), observe that the decode kernel recompiles at every step.

Log Analysis

The logs in log/test_sdar_diffulex_gsm8k.log show the following:

Generating: 100%|██████████| 10/10 [02:25<00:00, 14.55s/it, Prefill=13768tok/s, Decode=42tok/s]
2025-12-25 08:28:18  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel` with `out_idx=[10]`
2025-12-25 08:28:22  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `kernel`
2025-12-25 08:28:25  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `kernel` with `out_idx=[10]`
2025-12-25 08:28:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `kernel`
...
(Repeated recompilation, taking ~4-5 seconds per step)

**Performance Metrics:**
- **Prefill Throughput:** 13768.71 tok/s
- **Decode Throughput:** 42.53 tok/s (Severely degraded)
- **Average TPS:** 8.11 tok/s
- **Total Execution Time:** 145.53s

Code Analysis

In diffulex_kernel/python/dllm_flash_attn.py:

  1. Prefill Kernel: Uses the @tilelang.autotune decorator (line 24) to perform tuning during the warmup phase. https://github.com/zhijie-group/Diffulex/blob/main/diffulex_kernel/python/dllm_flash_attn.py#L167

  2. Decode Kernel: Only uses the @tilelang.jit decorator (line 164) without autotuning:
    https://github.com/zhijie-group/Diffulex/blob/main/diffulex_kernel/python/dllm_flash_attn.py#L28

    @tilelang.jit(
        out_idx=[-1],
        pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True},
    )
    def dllm_flash_attn_decode_kernel(...):
  3. Trigger Point: Inside the dllm_flash_attn_decode function (lines 609-632), a new kernel instance is created on every call:
    https://github.com/zhijie-group/Diffulex/blob/main/diffulex/attention/metadata.py#L20

    def dllm_flash_attn_decode(...):
        if attn_metadata.decode_mode == "static":
            # New kernel instance created per call, triggering JIT
            decode_kernel = dllm_flash_attn_decode_kernel(
                attn_metadata.num_seqs,
                ...
                **kernel_config
            )
            return decode_kernel(...)
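A common workaround for this class of problem is to hoist kernel construction out of the hot path and memoize it on every argument that affects code generation. The sketch below is not TileLang API; `build_decode_kernel` is a hypothetical stand-in for `dllm_flash_attn_decode_kernel`, and it assumes kernel construction is a pure function of its arguments:

```python
# Hypothetical stand-in for dllm_flash_attn_decode_kernel: constructing
# the kernel object is the step that triggers TileLang's JIT compilation.
def build_decode_kernel(num_seqs, block_m, block_n, num_stages, num_threads):
    build_decode_kernel.compile_count += 1  # track simulated compilations
    return f"kernel<{num_seqs},{block_m},{block_n},{num_stages},{num_threads}>"

build_decode_kernel.compile_count = 0

# Module-level cache, keyed on every parameter that changes the generated code.
_kernel_cache = {}

def get_decode_kernel(num_seqs, **kernel_config):
    key = (num_seqs, tuple(sorted(kernel_config.items())))
    if key not in _kernel_cache:
        _kernel_cache[key] = build_decode_kernel(num_seqs, **kernel_config)
    return _kernel_cache[key]
```

With this wrapper, repeated decode steps with the same `num_seqs` and `kernel_config` reuse the already-built kernel object, and construction (and hence JIT compilation) happens only when a genuinely new parameter combination appears. Note that if `num_seqs` (or any other cached key component) actually varies from step to step, the cache will miss every time, which would itself explain per-step recompilation.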

Root Cause Analysis

  1. The dllm_flash_attn_decode function instantiates a new kernel object on every invocation.
  2. Even though kernel_config is identical across calls, out_idx=[-1] is resolved against the runtime tensor arguments (logged as out_idx=[10]).
  3. The TileLang JIT compiler appears to recompile for every new kernel instance, even when the parameter combination is identical to a previous one.
  4. The lack of a kernel-instance caching mechanism turns this into massive compilation overhead during the decode phase.

Expected Behavior

  1. The decode kernel should compile only once, on the first call.
  2. Subsequent calls with the same parameter combination should reuse the cached, pre-compiled kernel.
  3. This would eliminate the multi-second compilation latency at each decode step.
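The expected behavior above can be modeled with a toy compile function; the names are hypothetical, and `functools.lru_cache` merely stands in for whatever caching layer TileLang provides internally:

```python
from functools import lru_cache

COMPILES = []  # records each distinct compilation event

@lru_cache(maxsize=None)
def compile_decode_kernel(num_seqs, block_m, block_n, num_stages, num_threads):
    # Stand-in for the expensive TileLang JIT step (~4-5 s per step in the logs).
    COMPILES.append((num_seqs, block_m, block_n, num_stages, num_threads))
    return ("compiled", num_seqs, block_m, block_n, num_stages, num_threads)

# Decode loop: identical parameters at every step should hit the cache,
# so only the first iteration pays the compilation cost.
for _step in range(100):
    kernel = compile_decode_kernel(8, 64, 64, 2, 128)
```

After the loop, `COMPILES` holds a single entry: 100 decode steps, one compilation. This is the behavior the issue expects from TileLang's JIT when the kernel parameters do not change between steps.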

Metadata

Assignees: no one assigned
Labels: question (further information is requested)