[BUG] Inconsistent CUDA Kernel Generation for Decode Attention with Paged KV Cache #1552

@drewjin

Description

Required prerequisites

What version of TileLang are you using?

0.1.7.post1

System information

Manually collected:

3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] linux
0.1.7.post1
2.9.0+cu128

torch.utils.collect_env:

PyTorch version: 2.9.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-216-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.6.20
CUDA_MODULE_LOADING set to: 
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090
GPU 4: NVIDIA GeForce RTX 3090
GPU 5: NVIDIA GeForce RTX 3090
GPU 6: NVIDIA GeForce RTX 3090
GPU 7: NVIDIA GeForce RTX 3090

Nvidia driver version: 570.133.07
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             80
On-line CPU(s) list:                0-79
Thread(s) per core:                 2
Core(s) per socket:                 20
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) Gold 6138T CPU @ 2.00GHz
Stepping:                           4
CPU MHz:                            1000.040
CPU max MHz:                        3700.0000
CPU min MHz:                        1000.0000
BogoMIPS:                           4000.00
Virtualization:                     VT-x
L1d cache:                          1.3 MiB
L1i cache:                          1.3 MiB
L2 cache:                           40 MiB
L3 cache:                           55 MiB
NUMA node0 CPU(s):                  0-19,40-59
NUMA node1 CPU(s):                  20-39,60-79
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        KVM: Mitigation: Split huge pages
Vulnerability L1tf:                 Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Mitigation; IBRS
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; IBRS; IBPB conditional; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT vulnerable
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==2.3.5
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.16.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.0
[pip3] torch_c_dlpack_ext==0.1.4
[pip3] torchaudio==2.9.0
[pip3] torchvision==0.24.0
[pip3] triton==3.5.0
[conda] Could not collect

Problem description

We are experiencing a critical issue where the TileLang compiler generates different CUDA kernels for the same source code, probably leading to correctness failures in production while the same code passes in test environments. The issue manifests as precision errors when running multi-round decode operations in the engine, even though the test suite passes.

Code Comparison

TileLang Source Code

The source code in diffulex_kernel/python/dllm_flash_attn_kernels.py:

@tilelang.jit(
    out_idx=[-1], 
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }
)
def dllm_flash_attn_decode_kernel(...):
    # ... initialization code ...
    
    # Stage 1: KV Cache Attention (Context)
    for page_block_idx_local in T.Pipelined(MAX_SEQ_NUM_BLOCKS, num_stages=NUM_STAGES):
        page_block_idx_global = block_tables[seq_idx, page_block_idx_local]
        if page_block_idx_global >= 0:
            T.copy(K_Cache[page_block_idx_global, :, kv_head_idx, :], K_Cache_shared)
            
            for i, j in T.Parallel(BLOCK_M, PAGE_BLOCK_SIZE):
                acc_score_kvcache[i, j] = T.if_then_else(
                    (i >= cur_q_seqlen
                     or page_block_idx_local * PAGE_BLOCK_SIZE + j >= cur_context_len),
                    -1e9, 0
                )
            
            # ... rest of the computation ...

Key Point: The boundary check uses a uniform formula for all blocks:

page_block_idx_local * PAGE_BLOCK_SIZE + j >= cur_context_len
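As a sanity check, the intended semantics of this formula can be expressed in plain Python. The following is a minimal sketch (the names mirror the kernel parameters, and the concrete numbers are illustrative; this is not the kernel itself):

PAGE_BLOCK_SIZE = 32

def should_mask(page_block_idx_local: int, j: int, cur_context_len: int) -> bool:
    # Token j of local page block page_block_idx_local maps to this absolute position:
    token_pos = page_block_idx_local * PAGE_BLOCK_SIZE + j
    # Positions at or beyond the current context length must be masked (score -1e9).
    return token_pos >= cur_context_len

# Illustrative check: with a 2040-token context, block 63 holds positions 2016..2047,
# so offsets 0..23 are valid and offsets 24..31 must be masked.
assert not should_mask(63, 23, 2040)
assert should_mask(63, 24, 2040)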

Generated CUDA Code

1. Loop Structure Differences

Failed Version:

  • Uses a simple for loop to process all blocks (0 to MAX_SEQ_NUM_BLOCKS-1)
  • Uses the same processing logic for all blocks

Success Version:

  • Prefetches data for the first block before the loop
  • The loop only processes the first N-1 blocks (0 to MAX_SEQ_NUM_BLOCKS-2)
  • The last block is handled separately, using different boundary check logic

2. Memory Access Pattern Differences

Failed Version:

  • Uses synchronous memory access *(uint4*)
  • Each iteration blocks while waiting for data loading to complete
  • Poorer performance but simpler logic

Success Version:

  • Uses asynchronous prefetch tl::cp_async_gs_conditional
  • Implements pipeline optimization: prefetches the next block while processing the current one (see the sketch after this list)
  • Requires correct synchronization points (tl::cp_async_wait<0>() and __syncthreads())
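To make the structural difference concrete, here is a host-side Python analogy of the two schedules. This is only a sketch; load_block and compute_block are hypothetical stand-ins for the cp_async prefetch and the per-block attention computation:

def run_simple(num_blocks, load_block, compute_block):
    # "Failed" version: one uniform loop, load then compute each block in turn.
    for idx in range(num_blocks):
        compute_block(idx, load_block(idx))

def run_pipelined(num_blocks, load_block, compute_block):
    # "Success" version: software pipelining with a peeled epilogue.
    buf = load_block(0)                    # prefetch block 0 before the loop
    for idx in range(num_blocks - 1):      # main loop covers blocks 0 .. N-2
        nxt = load_block(idx + 1)          # prefetch the next block (cp_async in CUDA)
        compute_block(idx, buf)            # process the current block
        buf = nxt                          # swap buffers (wait + __syncthreads in CUDA)
    compute_block(num_blocks - 1, buf)     # epilogue: last block handled separately

The peeled epilogue is where the last block can receive specialized boundary-check code during lowering, which matches the difference described in the next section.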

3. Boundary Check Condition Differences

This is the most critical difference, directly affecting the calculation of attention scores:

Failed Version (Unified for all blocks):

// Line 84: Uses the same boundary check for all blocks
if (cur_context_len <= (page_block_idx_local * 32 + offset)) {
    // mask this position
}

Success Version (Special handling for the last block):

// Line 89: Blocks inside the main loop
if (cur_context_len <= (page_block_idx_local * 32 + offset)) {
    // mask
}

// Line 232: The last block - uses a different formula!
if ((cur_context_len + 32) <= (MAX_SEQ_NUM_BLOCKS * 32 + offset)) {
    // mask
}

Key Differences:

  1. The success version adds +32 to cur_context_len for the last block.
  2. The success version uses MAX_SEQ_NUM_BLOCKS * 32 instead of page_block_idx_local * 32.

Impact Example:
Suppose MAX_SEQ_NUM_BLOCKS = 64, page_block_idx_local = 63, cur_context_len = 2048:

  • Failed Version: 2048 <= (63 * 32 + offset), i.e. 2048 <= (2016 + offset)
    • If the offset is small, the condition is false (no mask), potentially including tokens that should not enter the calculation.
  • Success Version: (2048 + 32) <= (64 * 32 + offset), i.e. 2080 <= (2048 + offset)
    • If the offset is small, the condition is true (mask), correctly masking tokens that exceed the context_len.

This results in:

  • Certain tokens in the last block being incorrectly included or excluded from the attention calculation.
  • Directly affecting attention scores and final output precision.

Reproducible example code

1. Generate Failed Test Cases

The engine automatically saves a failed test case whenever a precision error is detected: the checker function (CHECK_FLASH_ATTN_DECODE in test/python/utils/checker.py) saves the generated CUDA kernel code, along with the input/output tensors, into failed_test_cases.
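For illustration, the save-on-failure pattern can be as simple as the following. This is a minimal sketch with hypothetical names and paths; the real logic lives in CHECK_FLASH_ATTN_DECODE:

import os
import pickle

import torch

def check_and_save_on_failure(kernel_out, ref_out, kernel_source,
                              save_dir="failed_test_cases/example_case"):
    # Compare kernel output against the reference; dump artifacts on mismatch.
    try:
        torch.testing.assert_close(kernel_out, ref_out, atol=0.01, rtol=0.01)
    except AssertionError:
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "test_data.pkl"), "wb") as f:
            pickle.dump({"kernel_out": kernel_out.cpu(), "ref_out": ref_out.cpu()}, f)
        with open(os.path.join(save_dir, "kernel.cu"), "w") as f:
            f.write(kernel_source)  # generated CUDA source, for offline diffing
        raise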

To quickly reproduce this failed test case, you can run:

https://github.com/drewjin/Tilelang-failed_test_cases/blob/master/failed_test_cases/decode_kernel_failure_20251228_142155_370544/reproduce_test.py

2. Run the All-Passed Test Cases

Run the test file and inspect the generated CUDA kernel code. One way to pin down the divergence is to diff the kernel.cu from the failed run against the one from a passing run, as sketched below.
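A minimal sketch using only the standard library (the two file paths are supplied by the caller):

import difflib
import sys

def diff_kernels(failed_path: str, passing_path: str) -> None:
    # Print a unified diff between two generated kernel.cu files to locate
    # the divergent loop structure and boundary checks.
    with open(failed_path) as f:
        failed = f.readlines()
    with open(passing_path) as f:
        passing = f.readlines()
    sys.stdout.writelines(
        difflib.unified_diff(failed, passing,
                             fromfile=failed_path, tofile=passing_path)
    )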

Traceback

[rank0]: Traceback (most recent call last):
[rank0]:   File "/data1/jyj/Diffulex/test/python/utils/checker.py", line 248, in CHECK_FLASH_ATTN_DECODE
[rank0]:     torch.testing.assert_close(
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1589, in assert_close
[rank0]:     raise error_metas[0].to_error(msg)
[rank0]: AssertionError: Decode kernel output does not match reference implementation

[rank0]: During handling of the above exception, another exception occurred:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/data1/jyj/.cache/micromamba/envs/syspy/lib/python3.12/runpy.py", line 198, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/.cache/micromamba/envs/syspy/lib/python3.12/runpy.py", line 88, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
[rank0]:     cli.main()
[rank0]:   File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 508, in main
[rank0]:     run()
[rank0]:   File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 358, in run_file
[rank0]:     runpy.run_path(target, run_name="__main__")
[rank0]:   File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
[rank0]:     return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
[rank0]:     _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
[rank0]:   File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/home/jyj/workspace/Diffulex/examples/test_fastdllmv2_diffulex_gsm8k.py", line 76, in <module>
[rank0]:     outputs = LLM.generate(prompts, sampling_params)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/engine/tp_worker.py", line 101, in generate
[rank0]:     output, num_tokens, is_prefill, cur_n_diff_steps, _ = self.step()
[rank0]:                                                           ^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/engine/tp_worker.py", line 68, in step
[rank0]:     sample_output = self.model_runner.call("run", seqs, is_prefill)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/engine/model_runner.py", line 112, in call
[rank0]:     return method(*args)
[rank0]:            ^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/strategy/block_diffusion/engine/model_runner.py", line 183, in run
[rank0]:     logits = self.run_model(input_ids, positions, is_prefill)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/strategy/block_diffusion/engine/model_runner.py", line 160, in run_model
[rank0]:     return self.model.compute_logits(self.model(input_ids, positions))
[rank0]:                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 230, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, mask)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 204, in forward
[rank0]:     hidden_states, residual = layer(positions, hidden_states, residual, mask)
[rank0]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 177, in forward
[rank0]:     hidden_states = self.self_attn(positions, hidden_states, mask)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 99, in forward
[rank0]:     o = self.attn(q, k, v, mask)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex/attention/attn_impl.py", line 67, in forward
[rank0]:     o = dllm_flash_attn_decode(q, k, v, k_cache, v_cache, self.scale, attn_metadata)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data1/jyj/Diffulex/diffulex_kernel/python/dllm_flash_attn_kernels.py", line 464, in dllm_flash_attn_decode
[rank0]:     CHECK_FLASH_ATTN_DECODE(
[rank0]:   File "/data1/jyj/Diffulex/test/python/utils/checker.py", line 505, in CHECK_FLASH_ATTN_DECODE
[rank0]:     raise AssertionError(
[rank0]: AssertionError: Decode kernel verification failed!
[rank0]: Max absolute difference: 0.062500
[rank0]: Mean absolute difference: 0.000483
[rank0]: Max relative difference: 18176.000000
[rank0]: Mean relative difference: 0.020630
[rank0]: Total elements: 1605632
[rank0]: Elements exceeding absolute tolerance (atol=0.01): 8161 (0.51%)
[rank0]: Elements exceeding relative tolerance (rtol=0.01): 55206 (3.44%)
[rank0]: Elements exceeding either tolerance: 63362 (3.95%)
[rank0]: Kernel output shape: torch.Size([448, 28, 128])
[rank0]: Reference output shape: torch.Size([448, 28, 128])


[rank0]: Mismatched elements (showing first 50 of 63362):
[rank0]: ----------------------------------------------------------------------------------------------------
[rank0]: Index                          Kernel Value         Ref Value            Abs Diff        Rel Diff       
[rank0]: ----------------------------------------------------------------------------------------------------
[rank0]: (0, 0, 8)                                 0.012451            0.011963       0.000488       0.040771
[rank0]: (0, 0, 28)                                0.112793            0.114746       0.001953       0.016968
[rank0]: (0, 0, 37)                               -2.062500           -2.078125       0.015625       0.007507
[rank0]: (0, 0, 69)                               -0.017822           -0.016357       0.001465       0.089355
[rank0]: (0, 0, 76)                               -0.079590           -0.078125       0.001465       0.018799
[rank0]: (0, 0, 79)                                0.016357            0.016968       0.000610       0.035889
[rank0]: (0, 0, 85)                               -0.034424           -0.033691       0.000732       0.021729
[rank0]: (0, 0, 89)                               -0.042236           -0.043457       0.001221       0.028076
[rank0]: (0, 0, 97)                               -0.022583           -0.023193       0.000610       0.026367
[rank0]: (0, 1, 22)                               -0.025024           -0.025757       0.000732       0.028442
[rank0]: (0, 1, 95)                                0.066895            0.067871       0.000977       0.014404
[rank0]: (0, 1, 101)                              -0.033447           -0.032959       0.000488       0.014832
[rank0]: (0, 2, 11)                                0.033447            0.034180       0.000732       0.021484
[rank0]: (0, 2, 15)                                0.001480            0.002151       0.000671       0.312500
[rank0]: (0, 2, 27)                                2.203125            2.218750       0.015625       0.007050
[rank0]: (0, 2, 53)                               -2.671875           -2.656250       0.015625       0.005890
[rank0]: (0, 2, 87)                               -0.012695           -0.014160       0.001465       0.103516
[rank0]: (0, 3, 12)                                0.060791            0.061768       0.000977       0.015869
[rank0]: (0, 3, 23)                               -2.687500           -2.703125       0.015625       0.005768
[rank0]: (0, 4, 14)                                0.008179            0.008362       0.000183       0.021851
[rank0]: (0, 4, 17)                                0.008057            0.008362       0.000305       0.036377
[rank0]: (0, 4, 21)                                0.005676            0.006561       0.000885       0.134766
[rank0]: (0, 4, 96)                                0.099121            0.100586       0.001465       0.014587
[rank0]: (0, 4, 116)                               0.118652            0.120117       0.001465       0.012207
[rank0]: (0, 5, 2)                                -0.010925           -0.010376       0.000549       0.052979
[rank0]: (0, 5, 4)                                 0.014771            0.014465       0.000305       0.021118
[rank0]: (0, 5, 5)                                 0.061035            0.062012       0.000977       0.015747
[rank0]: (0, 5, 26)                               -0.004639           -0.004730       0.000092       0.019409
[rank0]: (0, 5, 69)                                0.052979            0.052002       0.000977       0.018799
[rank0]: (0, 6, 15)                                0.032959            0.033447       0.000488       0.014587
[rank0]: (0, 6, 24)                                0.043701            0.043213       0.000488       0.011292
[rank0]: (0, 6, 27)                               -0.023926           -0.023560       0.000366       0.015564
[rank0]: (0, 6, 42)                                0.052002            0.051270       0.000732       0.014282
[rank0]: (0, 6, 69)                                0.000131            0.001091       0.000961       0.882812
[rank0]: (0, 6, 87)                               -0.060059           -0.061768       0.001709       0.027710
[rank0]: (0, 6, 100)                              -0.041748           -0.041260       0.000488       0.011841
[rank0]: (0, 6, 101)                               0.081543            0.080078       0.001465       0.018311
[rank0]: (0, 6, 106)                              -0.111816           -0.113281       0.001465       0.012939
[rank0]: (0, 6, 110)                               0.005280            0.004791       0.000488       0.102051
[rank0]: (0, 8, 3)                                 2.718750            2.734375       0.015625       0.005707
[rank0]: (0, 8, 34)                               -0.003693           -0.003998       0.000305       0.076172
[rank0]: (0, 8, 42)                               -0.000679           -0.000717       0.000038       0.053223
[rank0]: (0, 8, 54)                               -0.026489           -0.026123       0.000366       0.014038
[rank0]: (0, 8, 125)                               2.859375            2.875000       0.015625       0.005432
[rank0]: (0, 9, 3)                                 2.984375            2.968750       0.015625       0.005249
[rank0]: (0, 9, 34)                                0.008911            0.009094       0.000183       0.020142
[rank0]: (0, 11, 98)                              -0.001869           -0.001907       0.000038       0.020020
[rank0]: (0, 14, 32)                               0.171875            0.169922       0.001953       0.011475
[rank0]: (0, 14, 42)                               0.052246            0.053223       0.000977       0.018311
[rank0]: (0, 14, 80)                               0.093750            0.094727       0.000977       0.010315

[rank0]: ... and 63312 more mismatches

[rank0]: Mismatch distribution by dimensions:
[rank0]:   Dim 0 (size 448): [120, 173, 151, 174, 154, 134, 153, 176, 175, 159, 182, 192, 246, 164, 181, 127, 175, 105, 135, 146, 146, 97, 97, 112, 123, 95, 100, 131, 116, 150, 147, 137, 133, 134, 100, 100, 98, 94, 95, 101, 140, 121, 114, 172, 147, 126, 134, 132, 144, 137, 135, 141, 149, 118, 138, 129, 136, 138, 96, 133, 128, 140, 141, 114, 185, 147, 148, 147, 165, 107, 138, 133, 124, 147, 132, 125, 87, 72, 71, 105, 82, 109, 92, 119, 170, 121, 118, 104, 112, 113, 125, 140, 119, 138, 119, 145, 112, 165, 113, 127, 133, 127, 100, 134, 84, 106, 115, 122, 117, 123, 145, 119, 150, 139, 154, 120, 103, 124, 101, 112, 122, 127, 105, 118, 125, 118, 111, 122, 155, 184, 177, 215, 181, 184, 193, 156, 166, 176, 202, 183, 196, 202, 173, 134, 194, 171, 160, 151, 175, 201, 129, 179, 160, 184, 191, 169, 171, 173, 133, 191, 146, 163, 141, 155, 142, 137, 135, 170, 186, 128, 131, 124, 152, 138, 152, 115, 85, 102, 55, 84, 100, 64, 86, 90, 123, 115, 106, 120, 116, 99, 151, 190, 217, 169, 161, 164, 168, 136, 182, 160, 163, 160, 150, 129, 148, 143, 140, 146, 151, 152, 130, 151, 149, 162, 156, 149, 142, 153, 162, 168, 151, 160, 159, 139, 45, 57, 63, 41, 31, 45, 79, 43, 53, 80, 67, 64, 38, 52, 56, 67, 62, 76, 64, 67, 79, 75, 51, 53, 64, 61, 78, 62, 68, 67, 63, 75, 130, 174, 180, 175, 144, 152, 130, 132, 150, 153, 167, 142, 149, 156, 169, 167, 141, 152, 153, 133, 141, 144, 179, 140, 189, 162, 165, 178, 180, 160, 180, 151, 153, 141, 168, 150, 148, 181, 120, 137, 139, 117, 164, 191, 121, 88, 156, 141, 146, 136, 147, 165, 130, 129, 135, 146, 109, 97, 100, 106, 99, 133, 99, 164, 182, 182, 184, 217, 208, 168, 164, 182, 146, 155, 138, 132, 130, 182, 167, 146, 145, 155, 177, 179, 168, 158, 165, 191, 146, 169, 169, 158, 199, 192, 167, 159, 186, 196, 199, 225, 195, 179, 168, 166, 163, 162, 137, 136, 135, 161, 133, 153, 149, 141, 138, 156, 144, 135, 156, 145, 140, 158, 143, 167, 170, 150, 140, 152, 187, 155, 233, 211, 170, 130, 137, 165, 169, 89, 150, 140, 132, 144, 174, 133, 180, 137, 120, 103, 156, 108, 127, 132, 187, 134, 169, 162, 135, 129, 164, 157, 140, 176, 162, 188, 155, 143, 194, 183, 172, 118, 213, 161, 184, 186, 209, 96, 226, 165, 180, 167, 147, 228, 172, 184, 215, 218, 201, 144, 174, 151, 177, 173]
[rank0]:   Dim 1 (size 28): [2840, 2765, 3022, 2721, 2402, 2731, 3362, 180, 281, 575, 32, 30, 116, 81, 2911, 3115, 3007, 2601, 2987, 2933, 3453, 3191, 2701, 2949, 3265, 3261, 2914, 2936]
[rank0]:   Dim 2 (size 128): [507, 509, 503, 506, 500, 485, 499, 444, 429, 507, 415, 563, 553, 362, 539, 467, 550, 543, 510, 511, 550, 482, 546, 450, 472, 516, 453, 549, 470, 517, 606, 413, 518, 470, 498, 443, 535, 473, 381, 497, 500, 515, 522, 418, 501, 475, 465, 471, 552, 528, 493, 475, 586, 486, 503, 501, 544, 522, 520, 539, 431, 500, 569, 480, 470, 446, 447, 436, 508, 435, 527, 478, 525, 422, 520, 475, 530, 399, 470, 450, 535, 438, 527, 540, 421, 507, 483, 441, 514, 482, 472, 473, 591, 488, 524, 514, 485, 464, 493, 436, 587, 591, 477, 520, 553, 524, 522, 508, 433, 471, 530, 562, 522, 474, 457, 533, 474, 462, 522, 502, 510, 428, 480, 529, 514, 541, 434, 499]


[rank0]: Test case data saved to: failed_test_cases/decode_kernel_failure_20251228_163209_272379
[rank0]:   - test_data.pkl: All input/output tensors and metadata
[rank0]:   - reproduce_test.py: Script to reproduce the test case
[rank0]:   - error_summary.txt: Summary of the failure
[rank0]:   - kernel.cu: CUDA kernel source code
[rank0]: Original error: Decode kernel output does not match reference implementation

Expected behavior

  1. All blocks should use consistent boundary checking logic
  2. The boundary check should correctly mask tokens beyond cur_context_len
  3. The generated kernel should be deterministic for the same source code and parameters (see the determinism check sketched below)
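For item 3, a quick determinism check could hash the generated source across two compilations. This is a sketch, assuming the compiled TileLang kernel object exposes get_kernel_source() as in TileLang's examples; the kernel arguments are placeholders:

import hashlib

def kernel_source_digest(compiled_kernel) -> str:
    # Hash the generated CUDA source so two compilations can be compared cheaply.
    src = compiled_kernel.get_kernel_source()
    return hashlib.sha256(src.encode()).hexdigest()

# Compile the same TileLang program twice with identical shapes and pass
# configs; matching digests are a necessary condition for deterministic codegen:
# digest_a = kernel_source_digest(dllm_flash_attn_decode_kernel(<args>))
# digest_b = kernel_source_digest(dllm_flash_attn_decode_kernel(<args>))
# assert digest_a == digest_b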

Actual Behavior

  1. Test environment: generates a kernel with special handling for the last block (passes tests)
  2. Production environment: generates a kernel with uniform handling (fails precision checks)
  3. Inconsistent behavior: the same source code produces different kernels

Additional context

No response
