Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker that this hasn't already been reported. (comment there if it has.)
What version of TileLang are you using?
0.1.7.post1
System information
manual collect:
3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] linux
0.1.7.post1
2.9.0+cu128
torch.utils.collect_env:
PyTorch version: 2.9.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-216-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.6.20
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090
GPU 4: NVIDIA GeForce RTX 3090
GPU 5: NVIDIA GeForce RTX 3090
GPU 6: NVIDIA GeForce RTX 3090
GPU 7: NVIDIA GeForce RTX 3090
Nvidia driver version: 570.133.07
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 80
On-line CPU(s) list: 0-79
Thread(s) per core: 2
Core(s) per socket: 20
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 6138T CPU @ 2.00GHz
Stepping: 4
CPU MHz: 1000.040
CPU max MHz: 3700.0000
CPU min MHz: 1000.0000
BogoMIPS: 4000.00
Virtualization: VT-x
L1d cache: 1.3 MiB
L1i cache: 1.3 MiB
L2 cache: 40 MiB
L3 cache: 55 MiB
NUMA node0 CPU(s): 0-19,40-59
NUMA node1 CPU(s): 20-39,60-79
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT vulnerable
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke md_clear flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==2.3.5
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.16.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.0
[pip3] torch_c_dlpack_ext==0.1.4
[pip3] torchaudio==2.9.0
[pip3] torchvision==0.24.0
[pip3] triton==3.5.0
[conda] Could not collect
Problem description
We are seeing a critical issue: the TileLang compiler generates different CUDA kernels for the same source code, which is probably what causes correctness failures in production while the test environment passes. The issue shows up as precision errors when running multi-round decode operations in the engine, even though the test suite passes.
Code Comparison
TileLang Source Code
The source code in diffulex_kernel/python/dllm_flash_attn_kernels.py:
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }
)
def dllm_flash_attn_decode_kernel(...):
    # ... initialization code ...
    # Stage 1: KV Cache Attention (Context)
    for page_block_idx_local in T.Pipelined(MAX_SEQ_NUM_BLOCKS, num_stages=NUM_STAGES):
        page_block_idx_global = block_tables[seq_idx, page_block_idx_local]
        if page_block_idx_global >= 0:
            T.copy(K_Cache[page_block_idx_global, :, kv_head_idx, :], K_Cache_shared)
            for i, j in T.Parallel(BLOCK_M, PAGE_BLOCK_SIZE):
                acc_score_kvcache[i, j] = T.if_then_else(
                    (i >= cur_q_seqlen or
                     page_block_idx_local * PAGE_BLOCK_SIZE + j >= cur_context_len),
                    -1e9, 0
                )
    # ... rest of the computation ...
Key Point: The boundary check uses a uniform formula for all blocks:
page_block_idx_local * PAGE_BLOCK_SIZE + j >= cur_context_len
Generated CUDA code
- Failed Version: https://github.com/drewjin/Tilelang-failed_test_cases/blob/master/failed_test_cases/decode_kernel_failure_20251228_142155_370544/kernel.cu
1. Loop Structure Differences
Failed Version:
- Uses a simple for loop to process all blocks (0 to MAX_SEQ_NUM_BLOCKS-1)
- Uses the same processing logic for all blocks
Success Version:
- Prefetches data for the first block before the loop
- The loop only processes the first N-1 blocks (0 to MAX_SEQ_NUM_BLOCKS-2)
- The last block is handled separately, using different boundary check logic
2. Memory Access Pattern Differences
Failed Version:
- Uses synchronous memory access: *(uint4*)
- Each iteration blocks while waiting for data loading to complete
- Poorer performance but simpler logic
Success Version:
- Uses asynchronous prefetch: tl::cp_async_gs_conditional
- Implements pipeline optimization: prefetch the next block while processing the current block
- Requires correct synchronization points (tl::cp_async_wait<0>() and __syncthreads())
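The two loop shapes described above can be modelled in plain Python. This is only an illustrative sketch, not the generated CUDA: lists stand in for global and shared memory, and the names run_simple, run_pipelined, and tag are hypothetical. It shows that a prologue prefetch, a steady-state loop over the first N-1 blocks, and a peeled epilogue for the last block visit the same blocks, in the same order, as the uniform loop, provided each stage is handed the correct block index.

```python
# Illustrative model only: Python lists stand in for global/shared memory.

def run_simple(blocks, process):
    """Uniform loop: load and process every block, one after another."""
    out = []
    for idx in range(len(blocks)):
        tile = blocks[idx]                # synchronous load of block idx
        out.append(process(idx, tile))
    return out


def run_pipelined(blocks, process):
    """Prologue prefetch, loop over blocks 0..N-2, peeled epilogue for N-1."""
    n = len(blocks)
    if n == 0:
        return []
    out = []
    prefetched = blocks[0]                # prologue: prefetch block 0
    for idx in range(n - 1):              # steady state: blocks 0 .. N-2
        tile = prefetched
        prefetched = blocks[idx + 1]      # issue the load for the next block
        out.append(process(idx, tile))
    out.append(process(n - 1, prefetched))  # epilogue: last block, index N-1
    return out


def tag(idx, tile):
    """Hypothetical per-block work: record which index saw which tile."""
    return (idx, tile * 2)


blocks = [10, 20, 30, 40]
assert run_simple(blocks, tag) == run_pipelined(blocks, tag)
```

In this model the two schedules can only disagree if the epilogue receives a wrong index or a stale tile, which is the kind of discrepancy worth looking for when comparing the two kernel.cu files.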
3. Boundary Check Condition Differences
This is the most critical difference, directly affecting the calculation of attention scores:
Failed Version (Unified for all blocks):
// Line 84: Uses the same boundary check for all blocks
if (cur_context_len <= (page_block_idx_local * 32 + offset)) {
    // mask this position
}
Success Version (Special handling for the last block):
// Line 89: Blocks within the loop
if (cur_context_len <= (page_block_idx_local * 32 + offset)) {
    // mask
}
// Line 232: The last block - uses a different formula!
if ((cur_context_len + 32) <= (MAX_SEQ_NUM_BLOCKS * 32 + offset)) {
    // mask
}
Key Differences:
- The success version adds an offset of + 32 for the last block.
- The success version uses MAX_SEQ_NUM_BLOCKS * 32 instead of page_block_idx_local * 32.
Impact Example:
Suppose MAX_SEQ_NUM_BLOCKS = 64, page_block_idx_local = 63, cur_context_len = 2048, and offset is the position within the block (0 to 31):
- Failed Version: 2048 <= (63 * 32 + offset) = 2048 <= (2016 + offset). This is false for every offset from 0 to 31, so nothing in the last block is masked.
- Success Version: (2048 + 32) <= (64 * 32 + offset) = 2080 <= (2048 + offset). This is also false for every offset from 0 to 31.
Subtracting 32 from both sides shows that the last-block condition is the in-loop condition with page_block_idx_local = MAX_SEQ_NUM_BLOCKS - 1. The two checks therefore agree whenever offset is the same in-block index in both kernels; a different mask would have to come from how offset is computed in each kernel, which has not been verified here.
If the masks do differ, this would result in:
- Certain tokens in the last block being incorrectly included or excluded from the attention calculation.
- Directly affecting attention scores and final output precision.
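One way to sanity-check the two boundary conditions is to evaluate both exhaustively. The sketch below is a pure-Python model of the comparisons quoted from the generated kernels, not the kernels themselves. It assumes that offset is the same in-block index (0 to 31) in both versions, which the issue does not confirm, and it uses MAX_SEQ_NUM_BLOCKS = 64 from the example. The function names are hypothetical.

```python
PAGE_BLOCK_SIZE = 32       # block size that appears in the generated kernels
MAX_SEQ_NUM_BLOCKS = 64    # value taken from the impact example


def mask_in_loop(cur_context_len, page_block_idx_local, offset):
    """In-loop check: cur_context_len <= page_block_idx_local * 32 + offset."""
    return cur_context_len <= page_block_idx_local * PAGE_BLOCK_SIZE + offset


def mask_last_block(cur_context_len, offset):
    """Peeled check: cur_context_len + 32 <= MAX_SEQ_NUM_BLOCKS * 32 + offset."""
    return (cur_context_len + PAGE_BLOCK_SIZE
            <= MAX_SEQ_NUM_BLOCKS * PAGE_BLOCK_SIZE + offset)


def count_disagreements():
    """Compare both checks for the last block over all lengths and offsets."""
    last = MAX_SEQ_NUM_BLOCKS - 1
    disagreements = 0
    for cur_context_len in range(MAX_SEQ_NUM_BLOCKS * PAGE_BLOCK_SIZE + 1):
        for offset in range(PAGE_BLOCK_SIZE):
            a = mask_in_loop(cur_context_len, last, offset)
            b = mask_last_block(cur_context_len, offset)
            if a != b:
                disagreements += 1
    return disagreements


print(count_disagreements())  # prints 0
```

Under these assumptions the two formulas never disagree, which suggests checking the offset expressions and the cp_async synchronization in the two kernel.cu files before attributing the mismatch to the boundary check.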
Reproducible example code
1. Generate Failed Test Cases
The engine automatically saves failed test cases when precision errors are detected. The checker function, which is called from here, saves the CUDA kernel code into failed_test_cases.
To quickly reproduce this failed test case, you can run:
2. Run the All-Passed Test Cases
Run the test file and inspect the generated CUDA kernel code.
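To confirm that two runs really produced different kernels, the saved sources can be hashed and diffed. The following is a minimal sketch using only the Python standard library; the helper names and the default file labels are hypothetical, and the paths passed to compare_kernel_files are whatever kernel.cu files the two runs saved.

```python
import difflib
import hashlib
from pathlib import Path


def kernel_digest(source: str) -> str:
    """SHA-256 of the kernel source, for a quick same/different check."""
    return hashlib.sha256(source.encode("utf-8")).hexdigest()


def diff_kernels(a_src, b_src, a_name="failed/kernel.cu", b_name="passed/kernel.cu"):
    """Unified diff of two kernel sources as a list of lines."""
    return list(
        difflib.unified_diff(
            a_src.splitlines(),
            b_src.splitlines(),
            fromfile=a_name,
            tofile=b_name,
            lineterm="",
        )
    )


def compare_kernel_files(path_a, path_b):
    """Read two kernel.cu files; return (identical, diff_lines)."""
    a_src = Path(path_a).read_text()
    b_src = Path(path_b).read_text()
    identical = kernel_digest(a_src) == kernel_digest(b_src)
    return identical, diff_kernels(a_src, b_src, str(path_a), str(path_b))
```

Recording the digest next to each saved test case would make it easy to tell whether a later failure used the same generated kernel as a passing run.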
Traceback
[rank0]: Traceback (most recent call last):
[rank0]: File "/data1/jyj/Diffulex/test/python/utils/checker.py", line 248, in CHECK_FLASH_ATTN_DECODE
[rank0]: torch.testing.assert_close(
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1589, in assert_close
[rank0]: raise error_metas[0].to_error(msg)
[rank0]: AssertionError: Decode kernel output does not match reference implementation
[rank0]: During handling of the above exception, another exception occurred:
[rank0]: Traceback (most recent call last):
[rank0]: File "/data1/jyj/.cache/micromamba/envs/syspy/lib/python3.12/runpy.py", line 198, in _run_module_as_main
[rank0]: return _run_code(code, main_globals, None,
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/.cache/micromamba/envs/syspy/lib/python3.12/runpy.py", line 88, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
[rank0]: cli.main()
[rank0]: File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 508, in main
[rank0]: run()
[rank0]: File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 358, in run_file
[rank0]: runpy.run_path(target, run_name="__main__")
[rank0]: File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
[rank0]: return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
[rank0]: _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
[rank0]: File "/home/jyj/.cursor-server/extensions/ms-python.debugpy-2025.18.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/home/jyj/workspace/Diffulex/examples/test_fastdllmv2_diffulex_gsm8k.py", line 76, in <module>
[rank0]: outputs = LLM.generate(prompts, sampling_params)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/engine/tp_worker.py", line 101, in generate
[rank0]: output, num_tokens, is_prefill, cur_n_diff_steps, _ = self.step()
[rank0]: ^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/engine/tp_worker.py", line 68, in step
[rank0]: sample_output = self.model_runner.call("run", seqs, is_prefill)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/engine/model_runner.py", line 112, in call
[rank0]: return method(*args)
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/strategy/block_diffusion/engine/model_runner.py", line 183, in run
[rank0]: logits = self.run_model(input_ids, positions, is_prefill)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/strategy/block_diffusion/engine/model_runner.py", line 160, in run_model
[rank0]: return self.model.compute_logits(self.model(input_ids, positions))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 230, in forward
[rank0]: hidden_states = self.model(input_ids, positions, mask)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 204, in forward
[rank0]: hidden_states, residual = layer(positions, hidden_states, residual, mask)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 177, in forward
[rank0]: hidden_states = self.self_attn(positions, hidden_states, mask)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/model/fast_dllm_v2.py", line 99, in forward
[rank0]: o = self.attn(q, k, v, mask)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jyj/workspace/Diffulex/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex/attention/attn_impl.py", line 67, in forward
[rank0]: o = dllm_flash_attn_decode(q, k, v, k_cache, v_cache, self.scale, attn_metadata)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data1/jyj/Diffulex/diffulex_kernel/python/dllm_flash_attn_kernels.py", line 464, in dllm_flash_attn_decode
[rank0]: CHECK_FLASH_ATTN_DECODE(
[rank0]: File "/data1/jyj/Diffulex/test/python/utils/checker.py", line 505, in CHECK_FLASH_ATTN_DECODE
[rank0]: raise AssertionError(
[rank0]: AssertionError: Decode kernel verification failed!
[rank0]: Max absolute difference: 0.062500
[rank0]: Mean absolute difference: 0.000483
[rank0]: Max relative difference: 18176.000000
[rank0]: Mean relative difference: 0.020630
[rank0]: Total elements: 1605632
[rank0]: Elements exceeding absolute tolerance (atol=0.01): 8161 (0.51%)
[rank0]: Elements exceeding relative tolerance (rtol=0.01): 55206 (3.44%)
[rank0]: Elements exceeding either tolerance: 63362 (3.95%)
[rank0]: Kernel output shape: torch.Size([448, 28, 128])
[rank0]: Reference output shape: torch.Size([448, 28, 128])
[rank0]: Mismatched elements (showing first 50 of 63362):
[rank0]: ----------------------------------------------------------------------------------------------------
[rank0]: Index Kernel Value Ref Value Abs Diff Rel Diff
[rank0]: ----------------------------------------------------------------------------------------------------
[rank0]: (0, 0, 8) 0.012451 0.011963 0.000488 0.040771
[rank0]: (0, 0, 28) 0.112793 0.114746 0.001953 0.016968
[rank0]: (0, 0, 37) -2.062500 -2.078125 0.015625 0.007507
[rank0]: (0, 0, 69) -0.017822 -0.016357 0.001465 0.089355
[rank0]: (0, 0, 76) -0.079590 -0.078125 0.001465 0.018799
[rank0]: (0, 0, 79) 0.016357 0.016968 0.000610 0.035889
[rank0]: (0, 0, 85) -0.034424 -0.033691 0.000732 0.021729
[rank0]: (0, 0, 89) -0.042236 -0.043457 0.001221 0.028076
[rank0]: (0, 0, 97) -0.022583 -0.023193 0.000610 0.026367
[rank0]: (0, 1, 22) -0.025024 -0.025757 0.000732 0.028442
[rank0]: (0, 1, 95) 0.066895 0.067871 0.000977 0.014404
[rank0]: (0, 1, 101) -0.033447 -0.032959 0.000488 0.014832
[rank0]: (0, 2, 11) 0.033447 0.034180 0.000732 0.021484
[rank0]: (0, 2, 15) 0.001480 0.002151 0.000671 0.312500
[rank0]: (0, 2, 27) 2.203125 2.218750 0.015625 0.007050
[rank0]: (0, 2, 53) -2.671875 -2.656250 0.015625 0.005890
[rank0]: (0, 2, 87) -0.012695 -0.014160 0.001465 0.103516
[rank0]: (0, 3, 12) 0.060791 0.061768 0.000977 0.015869
[rank0]: (0, 3, 23) -2.687500 -2.703125 0.015625 0.005768
[rank0]: (0, 4, 14) 0.008179 0.008362 0.000183 0.021851
[rank0]: (0, 4, 17) 0.008057 0.008362 0.000305 0.036377
[rank0]: (0, 4, 21) 0.005676 0.006561 0.000885 0.134766
[rank0]: (0, 4, 96) 0.099121 0.100586 0.001465 0.014587
[rank0]: (0, 4, 116) 0.118652 0.120117 0.001465 0.012207
[rank0]: (0, 5, 2) -0.010925 -0.010376 0.000549 0.052979
[rank0]: (0, 5, 4) 0.014771 0.014465 0.000305 0.021118
[rank0]: (0, 5, 5) 0.061035 0.062012 0.000977 0.015747
[rank0]: (0, 5, 26) -0.004639 -0.004730 0.000092 0.019409
[rank0]: (0, 5, 69) 0.052979 0.052002 0.000977 0.018799
[rank0]: (0, 6, 15) 0.032959 0.033447 0.000488 0.014587
[rank0]: (0, 6, 24) 0.043701 0.043213 0.000488 0.011292
[rank0]: (0, 6, 27) -0.023926 -0.023560 0.000366 0.015564
[rank0]: (0, 6, 42) 0.052002 0.051270 0.000732 0.014282
[rank0]: (0, 6, 69) 0.000131 0.001091 0.000961 0.882812
[rank0]: (0, 6, 87) -0.060059 -0.061768 0.001709 0.027710
[rank0]: (0, 6, 100) -0.041748 -0.041260 0.000488 0.011841
[rank0]: (0, 6, 101) 0.081543 0.080078 0.001465 0.018311
[rank0]: (0, 6, 106) -0.111816 -0.113281 0.001465 0.012939
[rank0]: (0, 6, 110) 0.005280 0.004791 0.000488 0.102051
[rank0]: (0, 8, 3) 2.718750 2.734375 0.015625 0.005707
[rank0]: (0, 8, 34) -0.003693 -0.003998 0.000305 0.076172
[rank0]: (0, 8, 42) -0.000679 -0.000717 0.000038 0.053223
[rank0]: (0, 8, 54) -0.026489 -0.026123 0.000366 0.014038
[rank0]: (0, 8, 125) 2.859375 2.875000 0.015625 0.005432
[rank0]: (0, 9, 3) 2.984375 2.968750 0.015625 0.005249
[rank0]: (0, 9, 34) 0.008911 0.009094 0.000183 0.020142
[rank0]: (0, 11, 98) -0.001869 -0.001907 0.000038 0.020020
[rank0]: (0, 14, 32) 0.171875 0.169922 0.001953 0.011475
[rank0]: (0, 14, 42) 0.052246 0.053223 0.000977 0.018311
[rank0]: (0, 14, 80) 0.093750 0.094727 0.000977 0.010315
[rank0]: ... and 63312 more mismatches
[rank0]: Mismatch distribution by dimensions:
[rank0]: Dim 0 (size 448): [120, 173, 151, 174, 154, 134, 153, 176, 175, 159, 182, 192, 246, 164, 181, 127, 175, 105, 135, 146, 146, 97, 97, 112, 123, 95, 100, 131, 116, 150, 147, 137, 133, 134, 100, 100, 98, 94, 95, 101, 140, 121, 114, 172, 147, 126, 134, 132, 144, 137, 135, 141, 149, 118, 138, 129, 136, 138, 96, 133, 128, 140, 141, 114, 185, 147, 148, 147, 165, 107, 138, 133, 124, 147, 132, 125, 87, 72, 71, 105, 82, 109, 92, 119, 170, 121, 118, 104, 112, 113, 125, 140, 119, 138, 119, 145, 112, 165, 113, 127, 133, 127, 100, 134, 84, 106, 115, 122, 117, 123, 145, 119, 150, 139, 154, 120, 103, 124, 101, 112, 122, 127, 105, 118, 125, 118, 111, 122, 155, 184, 177, 215, 181, 184, 193, 156, 166, 176, 202, 183, 196, 202, 173, 134, 194, 171, 160, 151, 175, 201, 129, 179, 160, 184, 191, 169, 171, 173, 133, 191, 146, 163, 141, 155, 142, 137, 135, 170, 186, 128, 131, 124, 152, 138, 152, 115, 85, 102, 55, 84, 100, 64, 86, 90, 123, 115, 106, 120, 116, 99, 151, 190, 217, 169, 161, 164, 168, 136, 182, 160, 163, 160, 150, 129, 148, 143, 140, 146, 151, 152, 130, 151, 149, 162, 156, 149, 142, 153, 162, 168, 151, 160, 159, 139, 45, 57, 63, 41, 31, 45, 79, 43, 53, 80, 67, 64, 38, 52, 56, 67, 62, 76, 64, 67, 79, 75, 51, 53, 64, 61, 78, 62, 68, 67, 63, 75, 130, 174, 180, 175, 144, 152, 130, 132, 150, 153, 167, 142, 149, 156, 169, 167, 141, 152, 153, 133, 141, 144, 179, 140, 189, 162, 165, 178, 180, 160, 180, 151, 153, 141, 168, 150, 148, 181, 120, 137, 139, 117, 164, 191, 121, 88, 156, 141, 146, 136, 147, 165, 130, 129, 135, 146, 109, 97, 100, 106, 99, 133, 99, 164, 182, 182, 184, 217, 208, 168, 164, 182, 146, 155, 138, 132, 130, 182, 167, 146, 145, 155, 177, 179, 168, 158, 165, 191, 146, 169, 169, 158, 199, 192, 167, 159, 186, 196, 199, 225, 195, 179, 168, 166, 163, 162, 137, 136, 135, 161, 133, 153, 149, 141, 138, 156, 144, 135, 156, 145, 140, 158, 143, 167, 170, 150, 140, 152, 187, 155, 233, 211, 170, 130, 137, 165, 169, 89, 150, 140, 132, 144, 174, 133, 180, 137, 120, 103, 156, 
108, 127, 132, 187, 134, 169, 162, 135, 129, 164, 157, 140, 176, 162, 188, 155, 143, 194, 183, 172, 118, 213, 161, 184, 186, 209, 96, 226, 165, 180, 167, 147, 228, 172, 184, 215, 218, 201, 144, 174, 151, 177, 173]
[rank0]: Dim 1 (size 28): [2840, 2765, 3022, 2721, 2402, 2731, 3362, 180, 281, 575, 32, 30, 116, 81, 2911, 3115, 3007, 2601, 2987, 2933, 3453, 3191, 2701, 2949, 3265, 3261, 2914, 2936]
[rank0]: Dim 2 (size 128): [507, 509, 503, 506, 500, 485, 499, 444, 429, 507, 415, 563, 553, 362, 539, 467, 550, 543, 510, 511, 550, 482, 546, 450, 472, 516, 453, 549, 470, 517, 606, 413, 518, 470, 498, 443, 535, 473, 381, 497, 500, 515, 522, 418, 501, 475, 465, 471, 552, 528, 493, 475, 586, 486, 503, 501, 544, 522, 520, 539, 431, 500, 569, 480, 470, 446, 447, 436, 508, 435, 527, 478, 525, 422, 520, 475, 530, 399, 470, 450, 535, 438, 527, 540, 421, 507, 483, 441, 514, 482, 472, 473, 591, 488, 524, 514, 485, 464, 493, 436, 587, 591, 477, 520, 553, 524, 522, 508, 433, 471, 530, 562, 522, 474, 457, 533, 474, 462, 522, 502, 510, 428, 480, 529, 514, 541, 434, 499]
[rank0]: Test case data saved to: failed_test_cases/decode_kernel_failure_20251228_163209_272379
[rank0]: - test_data.pkl: All input/output tensors and metadata
[rank0]: - reproduce_test.py: Script to reproduce the test case
[rank0]: - error_summary.txt: Summary of the failure
[rank0]: - kernel.cu: CUDA kernel source code
[rank0]: Original error: Decode kernel output does not match reference implementation
Expected behavior
Expected Behavior
- All blocks should use consistent boundary checking logic
- The boundary check should correctly mask tokens beyond cur_context_len
- The generated kernel should be deterministic for the same source code and parameters
Actual Behavior
- Test environment: Generates kernel with special handling for last block (passes tests)
- Production environment: Generates kernel with uniform handling (fails precision checks)
- Inconsistent behavior: Same source code produces different kernels
Additional context
No response