
Conversation

@ZiguanWang
Contributor

@ZiguanWang ZiguanWang commented Jan 27, 2026

Modifications

  1. Fix the deepseek_mla AMD example.
  2. Add an aiter MLA comparison test.

Tests

python examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py --compare

comparing torch vs flash_mla_triton: b=64, s_q=1, mean_seqlens=1087.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 33.568 ms, 0.577 TFLOPS, 2.919 GB/s
perf flash_mla_triton: 1.083 ms, 17.891 TFLOPS, 90.456 GB/s
comparing torch vs flash_mla_triton: b=64, s_q=1, mean_seqlens=2111.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 56.801 ms, 0.662 TFLOPS, 3.054 GB/s
perf flash_mla_triton: 1.544 ms, 24.376 TFLOPS, 112.368 GB/s
comparing torch vs flash_mla_triton: b=64, s_q=1, mean_seqlens=4159.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 110.014 ms, 0.674 TFLOPS, 2.949 GB/s
perf flash_mla_triton: 2.190 ms, 33.857 TFLOPS, 148.175 GB/s
comparing torch vs flash_mla_triton: b=64, s_q=1, mean_seqlens=8255.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 216.139 ms, 0.681 TFLOPS, 2.898 GB/s
perf flash_mla_triton: 3.438 ms, 42.797 TFLOPS, 182.192 GB/s
comparing torch vs flash_mla_triton: b=64, s_q=1, mean_seqlens=16447.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 427.875 ms, 0.685 TFLOPS, 2.876 GB/s
perf flash_mla_triton: 6.021 ms, 48.696 TFLOPS, 204.368 GB/s
comparing torch vs flash_mla_triton: b=128, s_q=1, mean_seqlens=1151.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 70.943 ms, 0.578 TFLOPS, 2.895 GB/s
perf flash_mla_triton: 2.133 ms, 19.241 TFLOPS, 96.297 GB/s
comparing torch vs flash_mla_triton: b=128, s_q=1, mean_seqlens=2175.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 118.045 ms, 0.657 TFLOPS, 3.019 GB/s
perf flash_mla_triton: 3.025 ms, 25.634 TFLOPS, 117.808 GB/s
comparing torch vs flash_mla_triton: b=128, s_q=1, mean_seqlens=4223.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 225.619 ms, 0.667 TFLOPS, 2.918 GB/s
perf flash_mla_triton: 4.323 ms, 34.825 TFLOPS, 152.286 GB/s
comparing torch vs flash_mla_triton: b=128, s_q=1, mean_seqlens=8319.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 436.803 ms, 0.679 TFLOPS, 2.890 GB/s
perf flash_mla_triton: 6.888 ms, 43.061 TFLOPS, 183.277 GB/s
comparing torch vs flash_mla_triton: b=128, s_q=1, mean_seqlens=16511.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.float16
perf torch: 861.401 ms, 0.683 TFLOPS, 2.868 GB/s
perf flash_mla_triton: 12.045 ms, 48.869 TFLOPS, 205.084 GB/s

python examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py --compare

Loading tilelang libs from dev root: /home/roywan/tilelang/build
[aiter] import [module_aiter_enum] under /home/roywan/aiter/aiter/jit/module_aiter_enum.so
comparing torch vs mla_aiter: b=64, s_q=1, mean_seqlens=1087.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
[aiter] import [module_mla_asm] under /home/roywan/aiter/aiter/jit/module_mla_asm.so
[aiter] hipModuleLoad: /home/roywan/aiter/hsa//gfx942//mla/mla_dec_stage1_bf16_a16w16_subQ128_mqa128.co GetFunction: _ZN5aiter41mla_dec_stage1_bf16_a16w16_subQ128_mqa128E Success
perf torch: 33.097 ms, 0.585 TFLOPS, 2.960 GB/s
perf mla_aiter: 0.106 ms, 182.661 TFLOPS, 923.533 GB/s
comparing torch vs mla_aiter: b=64, s_q=1, mean_seqlens=2111.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 56.094 ms, 0.671 TFLOPS, 3.092 GB/s
perf mla_aiter: 0.151 ms, 248.406 TFLOPS, 1145.085 GB/s
comparing torch vs mla_aiter: b=64, s_q=1, mean_seqlens=4159.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 109.147 ms, 0.679 TFLOPS, 2.973 GB/s
perf mla_aiter: 0.300 ms, 247.078 TFLOPS, 1081.330 GB/s
comparing torch vs mla_aiter: b=64, s_q=1, mean_seqlens=8255.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 215.010 ms, 0.684 TFLOPS, 2.914 GB/s
perf mla_aiter: 0.473 ms, 311.369 TFLOPS, 1325.551 GB/s
comparing torch vs mla_aiter: b=64, s_q=1, mean_seqlens=16447.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 428.526 ms, 0.684 TFLOPS, 2.871 GB/s
perf mla_aiter: 0.906 ms, 323.650 TFLOPS, 1358.303 GB/s
comparing torch vs mla_aiter: b=128, s_q=1, mean_seqlens=1151.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 69.371 ms, 0.592 TFLOPS, 2.961 GB/s
perf mla_aiter: 0.173 ms, 236.939 TFLOPS, 1185.841 GB/s
comparing torch vs mla_aiter: b=128, s_q=1, mean_seqlens=2175.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 116.404 ms, 0.666 TFLOPS, 3.061 GB/s
perf mla_aiter: 0.282 ms, 275.030 TFLOPS, 1263.983 GB/s
comparing torch vs mla_aiter: b=128, s_q=1, mean_seqlens=4223.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 223.690 ms, 0.673 TFLOPS, 2.943 GB/s
perf mla_aiter: 0.543 ms, 277.283 TFLOPS, 1212.510 GB/s
comparing torch vs mla_aiter: b=128, s_q=1, mean_seqlens=8319.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 437.925 ms, 0.677 TFLOPS, 2.883 GB/s
perf mla_aiter: 0.933 ms, 317.949 TFLOPS, 1353.265 GB/s
comparing torch vs mla_aiter: b=128, s_q=1, mean_seqlens=16511.0, h_q=128, h_kv=1, d=576, dv=512, causal=True, dtype=torch.bfloat16
perf torch: 858.635 ms, 0.686 TFLOPS, 2.877 GB/s
perf mla_aiter: 1.598 ms, 368.458 TFLOPS, 1546.269 GB/s

python examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py

Using batch=128, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64
2026-01-27 08:45:58,882 WARNING:Tunable parameters ['block_N', 'block_H', 'num_split', 'threads'] already provided during auto-tuning. Skipping compilation and using direct JIT
2026-01-27 08:46:00  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main_split` with `out_idx=[6]`
2026-01-27 08:46:08  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main_split`
Latency: 7.468443870544434 ms
TFlops: 39.10557288645856 TFlops

python examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py --autotune

Using batch=128, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64
Latency: 1.5578199625015259 ms
TFlops: 187.4785168749652 TFlops
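
As a sanity check on these figures, assuming the usual MLA decode FLOP count of 2·(d + dv) per query token per head per cached KV token (this formula is an assumption used for illustration, not taken from the benchmark scripts):

    # Back-of-the-envelope check with values copied from the --autotune run above.
    b, s_q, h_q, seqlen, d, dv = 128, 1, 128, 8192, 576, 512
    flops = 2 * b * s_q * h_q * seqlen * (d + dv)   # ~0.29 TFLOP per decode step
    latency_s = 1.5578e-3                           # reported latency in seconds
    print(flops / latency_s / 1e12)                 # ~187.5, matching the TFlops printed above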

Summary by CodeRabbit

  • New Features

    • Added an AMD-optimized MLA attention benchmarking path (AIter) as the primary alternate target.
  • Improvements

    • Switched default numeric precision to bfloat16.
    • Improved performance metric formatting to show consistent decimal precision.
    • Refined benchmark batch-size exploration.
    • Added diagnostic runtime logging for configuration visibility.
    • Removed the previous Triton-based benchmark target in favor of the MLA AIter option.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 27, 2026

📝 Walkthrough


Replaces a Triton MLA kernel path with an AIter-based implementation and updates the benchmark entry/defaults; fixes a TileLang accumulation bug and adds diagnostic prints; tightens perf output formatting in the Triton benchmark prints.

Changes

  • AIter implementation & benchmark entry: examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py
    Adds the AIter-based MLA path (mla_decode_fwd import with fallback), introduces run_mla_aiter (replacing the previous Triton entry), updates FUNC_TABLE / available_targets to expose "mla_aiter", changes the default target to "mla_aiter", sets the default dtype to bfloat16, adjusts the batch sizes, and normalizes the perf print formatting.
  • TileLang fix & diagnostics: examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
    Fixes the accumulation to use the vector scale_local (was scale_local[0]; see the sketch after this list), adds diagnostic prints at startup and before autotuning, and removes some noisy output prints.
  • Formatting consistency for Triton benchmark prints: examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
    Changes the perf print formatting in compare_ab and compare_a to display TFLOPS and GB/s with three decimal places.
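
The effect of the scale_local change can be pictured with a generic online-softmax accumulator; this is a toy PyTorch sketch, not the TileLang kernel itself, and the tensor names simply mirror the summary above:

    import torch

    h, dv = 4, 8
    acc = torch.ones(h, dv)                       # running output accumulator, one row per head
    m_prev = torch.tensor([0.0, 1.0, 2.0, 3.0])   # running row maxima before a new KV block
    m_new = torch.tensor([1.0, 1.5, 2.0, 4.0])    # updated row maxima after the new block
    scale_local = torch.exp(m_prev - m_new)       # per-head rescale factors (a vector)

    fixed = acc * scale_local[:, None]            # vector form: each head scaled by its own factor
    buggy = acc * scale_local[0]                  # old form: head 0's factor applied to every head
    print((fixed - buggy).abs().max())            # nonzero whenever heads have different maxima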

Sequence Diagram(s)

sequenceDiagram
  participant Bench as Benchmark script
  participant AIter as aiter.mla (mla_decode_fwd)
  participant Torch as PyTorch tensors
  participant GPU as GPU runtime

  Bench->>Torch: Prepare inputs (kv, q, bias, scale, cfg)
  Bench->>AIter: call mla_decode_fwd(inputs)
  AIter->>Torch: allocate/operate on tensors
  AIter->>GPU: launch attention kernels (via AIter runtime)
  GPU-->>AIter: results tensor
  AIter-->>Bench: return output tensor
  Bench->>Bench: compare/print perf results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

🐰 Hopping through code where kernels part,
AIter arrives to play its part,
scales fixed by vector, prints trimmed and neat,
benchmarks hum a quicker beat,
carrots for tests — a tiny treat. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: the title directly addresses the main changes (fixing the deepseek_mla AMD example and adding an aiter MLA comparison test) and aligns with the primary modifications across the three benchmark files.




Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)

272-273: Prefer a --verbose flag or logging over commented-out debug prints.

Consider gating these with a CLI flag or logger rather than leaving commented code.
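
One way to gate that output, sketched with argparse plus the standard logging module (the flag and logger names here are hypothetical, not the benchmark's actual CLI):

    import argparse
    import logging

    logger = logging.getLogger("mla_benchmark")   # hypothetical logger name

    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true", help="emit per-run debug output")
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    # Debug details stay silent unless --verbose is passed, instead of living in commented-out prints.
    block_N, block_H, num_split, threads = 64, 64, 4, 256   # placeholder values for illustration
    logger.debug("config: block_N=%d block_H=%d num_split=%d threads=%d",
                 block_N, block_H, num_split, threads)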

examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py (1)

156-158: Update stale comment to match the new target.
The inline comment still references flash_mla_triton, but the guard now skips LSE for the new target.

📝 Proposed fix
-    if target not in ["mla_aiter"]:
-        # flash_mla_triton doesn't return lse
+    if target not in ["mla_aiter"]:
+        # mla_aiter doesn't return lse
         torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"

Comment on lines +14 to +18
try:
    from aiter.mla import mla_decode_fwd
except ImportError:
    print("aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device.")


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check if the file exists and get its basic info
fd "benchmark_mla_decode_amd_aiter.py" -type f

Repository: tile-ai/tilelang

Length of output: 230


🏁 Script executed:

# Search for aiter in project dependency files
rg -n "aiter" -S pyproject.toml setup.cfg setup.py requirements*.txt requirements*.in README* 2>/dev/null || echo "No aiter found in standard dependency files"

Repository: tile-ai/tilelang

Length of output: 346


🏁 Script executed:

# Read the full benchmark file to understand context and all usages
cat -n "examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py" 2>/dev/null || echo "File not found"

Repository: tile-ai/tilelang

Length of output: 12958


Add guard for mla_decode_fwd to prevent NameError when aiter is not installed.

If aiter is not installed, mla_decode_fwd remains undefined at line 99 when run_mla_aiter() is called. Since mla_aiter is the default target (line 224), this will crash immediately on typical execution. Set a sentinel value in the except block and check it before use:

Proposed fix
 try:
     from aiter.mla import mla_decode_fwd
 except ImportError:
+    mla_decode_fwd = None
     print("aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device.")
@@
 @torch.inference_mode()
 def run_mla_aiter(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     assert d > dv, "mla with rope dim should be larger than no rope dim"
+    if mla_decode_fwd is None:
+        raise RuntimeError("aiter is not installed. Install it or use --target torch.")
🤖 Prompt for AI Agents
In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py` around lines 14
- 18, The import except block currently leaves mla_decode_fwd undefined and will
cause a NameError when run_mla_aiter() calls it; set a sentinel (e.g.,
mla_decode_fwd = None or a dummy function) in the except ImportError, and then
add a guard in run_mla_aiter() (or wherever mla_decode_fwd is invoked) to check
that mla_decode_fwd is not None before calling it and raise or log a clear error
directing the user to install aiter (since mla_aiter is the default target) so
execution fails gracefully instead of crashing.

Comment on lines +68 to +114
 def run_mla_aiter(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
     blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

-    def flash_mla_triton():
-        num_kv_splits = 32
-        o = torch.empty([b * s_q, h_q, dv])
-        attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
-        mla_decode_triton(
-            q_nope.view(-1, h_q, dv),
-            q_pe.view(-1, h_q, d - dv),
-            blocked_k_nope.view(-1, dv),
-            blocked_k_pe.view(-1, d - dv),
-            o,
-            block_table,
-            cache_seqlens,
-            attn_logits,
-            num_kv_splits,
-            1 / math.sqrt(d),
-            block_size,
-        )
+    qo_indptr = torch.zeros(b + 1, dtype=torch.int)
+    kv_indptr = torch.zeros(b + 1, dtype=torch.int)
+    seq_lens_qo = torch.empty(b, dtype=torch.int)
+    seq_lens_qo.fill_(1)
+    max_seqlen_qo = seq_lens_qo.max().item()
+
+    kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0)
+    qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0)
+    total_q = qo_indptr[-1].item()
+
+    # set block_size to 1
+    page_size = 1
+    kv_buffer = blocked_k.view(-1, page_size, h_kv, d)
+
+    flat_indices = []
+    for i in range(b):
+        start = i * max_seqlen_pad
+        end = start + cache_seqlens[i]
+        flat_indices.append(torch.arange(start, end, dtype=torch.int))
+
+    kv_indices = torch.cat(flat_indices)
+
+    kv_last_page_lens = torch.ones(b, dtype=torch.int)
+
+    sm_scale = 1.0 / (d**0.5)
+
+    def mla_aiter():
+        out_aiter = torch.empty((total_q, h_q, dv), dtype=dtype).fill_(-1)
+        attn_logits_aiter, attn_lse_aiter = mla_decode_fwd(
+            q.view((total_q, h_q, d)),
+            kv_buffer,
+            out_aiter,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_lens,
+            max_seqlen_qo,
+            sm_scale,
+        )
-        return o.view([b, s_q, h_q, dv])
+        return out_aiter.view([b, s_q, h_q, dv])

-    out_flash = flash_mla_triton()
-    t = triton.testing.do_bench(flash_mla_triton)
-    return out_flash, None, t
+    out_aiter = mla_aiter()
+    t = triton.testing.do_bench(mla_aiter)
+    return out_aiter, None, t
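
For a concrete picture of the paged-KV metadata built above when page_size is 1, here is a tiny self-contained example with toy sizes (the values are illustrative only, not taken from the benchmark):

    import torch

    b, max_seqlen_pad = 2, 4
    cache_seqlens = torch.tensor([3, 2], dtype=torch.int)   # valid KV tokens per batch entry

    kv_indptr = torch.zeros(b + 1, dtype=torch.int)
    kv_indptr[1:] = torch.cumsum(cache_seqlens, dim=0)      # -> [0, 3, 5]

    # With one token per page, page i * max_seqlen_pad + j holds token j of batch i,
    # so the padded tail of each row is simply never referenced.
    flat_indices = [
        torch.arange(i * max_seqlen_pad, i * max_seqlen_pad + int(cache_seqlens[i]), dtype=torch.int)
        for i in range(b)
    ]
    kv_indices = torch.cat(flat_indices)                    # -> [0, 1, 2, 4, 5]

    kv_last_page_lens = torch.ones(b, dtype=torch.int)      # every page holds exactly one token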

⚠️ Potential issue | 🟡 Minor

Use s_q (and q.device) when building Qo metadata.
seq_lens_qo is hard-coded to 1 and metadata tensors rely on the global default device. If s_q changes or the default device isn’t set, q.view and the metadata can mismatch. Consider deriving from s_q and placing metadata on q.device.

♻️ Proposed fix
-    qo_indptr = torch.zeros(b + 1, dtype=torch.int)
-    kv_indptr = torch.zeros(b + 1, dtype=torch.int)
-    seq_lens_qo = torch.empty(b, dtype=torch.int)
-    seq_lens_qo.fill_(1)
-    max_seqlen_qo = seq_lens_qo.max().item()
+    device = q.device
+    qo_indptr = torch.zeros(b + 1, dtype=torch.int, device=device)
+    kv_indptr = torch.zeros(b + 1, dtype=torch.int, device=device)
+    seq_lens_qo = torch.full((b,), s_q, dtype=torch.int, device=device)
+    max_seqlen_qo = s_q
@@
-    kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0)
-    qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0)
+    kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0)
+    qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0)
@@
-    flat_indices.append(torch.arange(start, end, dtype=torch.int))
+    flat_indices.append(torch.arange(start, end, dtype=torch.int, device=device))
@@
-    kv_last_page_lens = torch.ones(b, dtype=torch.int)
+    kv_last_page_lens = torch.ones(b, dtype=torch.int, device=device)
🤖 Prompt for AI Agents
In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py` around lines 68
- 114, The Qo metadata is created on the global default device and uses a
hard-coded seq length of 1; update run_mla_aiter so seq_lens_qo is derived from
s_q (e.g., torch.full((b,), s_q, dtype=torch.int, device=q.device) or
torch.ones(...)*s_q) and create qo_indptr, kv_indptr, kv_last_page_lens and any
torch.arange calls (flat_indices) on q.device (and matching dtype) so they live
on the same device as q; ensure max_seqlen_qo is computed from the new
seq_lens_qo and build kv_indices from device-local tensors to avoid
q.view/device mismatches when calling mla_decode_fwd.

Member

@LeiWang1999 LeiWang1999 left a comment


Overall LGTM, thanks for your contributions! But I left a small comment.

@LeiWang1999 LeiWang1999 merged commit 413ecbb into tile-ai:main Jan 28, 2026
2 checks passed
@LeiWang1999
Member

LGTM, merged :)
