[fix]: fix deepseek_mla amd example and add aiter mla compare test #1740
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […]. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Replaces a Triton MLA kernel path with an AIter-based implementation and updates the benchmark entry/defaults; fixes a TileLang accumulation bug and adds diagnostic prints; tightens perf output formatting in the Triton benchmark prints.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Bench as Benchmark script
    participant AIter as aiter.mla (mla_decode_fwd)
    participant Torch as PyTorch tensors
    participant GPU as GPU runtime
    Bench->>Torch: Prepare inputs (kv, q, bias, scale, cfg)
    Bench->>AIter: call mla_decode_fwd(inputs)
    AIter->>Torch: allocate/operate on tensors
    AIter->>GPU: launch attention kernels (via AIter runtime)
    GPU-->>AIter: results tensor
    AIter-->>Bench: return output tensor
    Bench->>Bench: compare/print perf results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py`:
- Around line 68-114: The Qo metadata is created on the global default device
and uses a hard-coded seq length of 1; update run_mla_aiter so seq_lens_qo is
derived from s_q (e.g., torch.full((b,), s_q, dtype=torch.int, device=q.device)
or torch.ones(...)*s_q) and create qo_indptr, kv_indptr, kv_last_page_lens and
any torch.arange calls (flat_indices) on q.device (and matching dtype) so they
live on the same device as q; ensure max_seqlen_qo is computed from the new
seq_lens_qo and build kv_indices from device-local tensors to avoid
q.view/device mismatches when calling mla_decode_fwd.
- Around line 14-18: The import except block currently leaves mla_decode_fwd
undefined and will cause a NameError when run_mla_aiter() calls it; set a
sentinel (e.g., mla_decode_fwd = None or a dummy function) in the except
ImportError, and then add a guard in run_mla_aiter() (or wherever mla_decode_fwd
is invoked) to check that mla_decode_fwd is not None before calling it and raise
or log a clear error directing the user to install aiter (since mla_aiter is the
default target) so execution fails gracefully instead of crashing.
🧹 Nitpick comments (2)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)
272-273: Prefer a `--verbose` flag or logging over commented-out debug prints. Consider gating these with a CLI flag or logger rather than leaving commented code (a sketch follows these nitpicks).
examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py (1)
156-158: Update stale comment to match the new target.
The inline comment still references `flash_mla_triton`, but the guard now skips LSE for the new target.
📝 Proposed fix
```diff
-    if target not in ["mla_aiter"]:
-        # flash_mla_triton doesn't return lse
+    if target not in ["mla_aiter"]:
+        # mla_aiter doesn't return lse
         torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
```
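On the `--verbose` nitpick above: a minimal sketch of gating diagnostics behind a flag, assuming the benchmark already parses its CLI with argparse (the flag name and logger setup here are illustrative, not taken from the PR):

```python
import argparse
import logging

logger = logging.getLogger("benchmark_mla_decode_amd_tilelang")


def parse_args():
    parser = argparse.ArgumentParser()
    # ... existing benchmark arguments would go here ...
    parser.add_argument("--verbose", action="store_true", help="Enable debug prints")
    return parser.parse_args()


def main():
    args = parse_args()
    # Route diagnostics through logging instead of commented-out prints.
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    logger.debug("intermediate accumulator stats: ...")


if __name__ == "__main__":
    main()
```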
```python
try:
    from aiter.mla import mla_decode_fwd
except ImportError:
    print("aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device.")
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, check if the file exists and get its basic info
fd "benchmark_mla_decode_amd_aiter.py" -type f
```

Repository: tile-ai/tilelang
Length of output: 230

🏁 Script executed:

```shell
# Search for aiter in project dependency files
rg -n "aiter" -S pyproject.toml setup.cfg setup.py requirements*.txt requirements*.in README* 2>/dev/null || echo "No aiter found in standard dependency files"
```

Repository: tile-ai/tilelang
Length of output: 346

🏁 Script executed:

```shell
# Read the full benchmark file to understand context and all usages
cat -n "examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py" 2>/dev/null || echo "File not found"
```

Repository: tile-ai/tilelang
Length of output: 12958
Add guard for mla_decode_fwd to prevent NameError when aiter is not installed.
If aiter is not installed, mla_decode_fwd remains undefined at line 99 when run_mla_aiter() is called. Since mla_aiter is the default target (line 224), this will crash immediately on typical execution. Set a sentinel value in the except block and check it before use:
Proposed fix

```diff
 try:
     from aiter.mla import mla_decode_fwd
 except ImportError:
+    mla_decode_fwd = None
     print("aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device.")


 @torch.inference_mode()
 def run_mla_aiter(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     assert d > dv, "mla with rope dim should be larger than no rope dim"
+    if mla_decode_fwd is None:
+        raise RuntimeError("aiter is not installed. Install it or use --target torch.")
```
+ raise RuntimeError("aiter is not installed. Install it or use --target torch.")🤖 Prompt for AI Agents
In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py` around lines 14-18: The import except block currently leaves mla_decode_fwd undefined and will
cause a NameError when run_mla_aiter() calls it; set a sentinel (e.g.,
mla_decode_fwd = None or a dummy function) in the except ImportError, and then
add a guard in run_mla_aiter() (or wherever mla_decode_fwd is invoked) to check
that mla_decode_fwd is not None before calling it and raise or log a clear error
directing the user to install aiter (since mla_aiter is the default target) so
execution fails gracefully instead of crashing.
```diff
 def run_mla_aiter(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
     blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

-    def flash_mla_triton():
-        num_kv_splits = 32
-        o = torch.empty([b * s_q, h_q, dv])
-        attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
-        mla_decode_triton(
-            q_nope.view(-1, h_q, dv),
-            q_pe.view(-1, h_q, d - dv),
-            blocked_k_nope.view(-1, dv),
-            blocked_k_pe.view(-1, d - dv),
-            o,
-            block_table,
-            cache_seqlens,
-            attn_logits,
-            num_kv_splits,
-            1 / math.sqrt(d),
-            block_size,
-        )

+    qo_indptr = torch.zeros(b + 1, dtype=torch.int)
+    kv_indptr = torch.zeros(b + 1, dtype=torch.int)
+    seq_lens_qo = torch.empty(b, dtype=torch.int)
+    seq_lens_qo.fill_(1)
+    max_seqlen_qo = seq_lens_qo.max().item()
+
+    kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0)
+    qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0)
+    total_q = qo_indptr[-1].item()
+
+    # set block_size to 1
+    page_size = 1
+    kv_buffer = blocked_k.view(-1, page_size, h_kv, d)
+
+    flat_indices = []
+    for i in range(b):
+        start = i * max_seqlen_pad
+        end = start + cache_seqlens[i]
+        flat_indices.append(torch.arange(start, end, dtype=torch.int))
+
+    kv_indices = torch.cat(flat_indices)
+
+    kv_last_page_lens = torch.ones(b, dtype=torch.int)
+
+    sm_scale = 1.0 / (d**0.5)
+
+    def mla_aiter():
+        out_aiter = torch.empty((total_q, h_q, dv), dtype=dtype).fill_(-1)
+        attn_logits_aiter, attn_lse_aiter = mla_decode_fwd(
+            q.view((total_q, h_q, d)),
+            kv_buffer,
+            out_aiter,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_lens,
+            max_seqlen_qo,
+            sm_scale,
+        )
-        return o.view([b, s_q, h_q, dv])
+        return out_aiter.view([b, s_q, h_q, dv])

-    out_flash = flash_mla_triton()
-    t = triton.testing.do_bench(flash_mla_triton)
-    return out_flash, None, t
+    out_aiter = mla_aiter()
+    t = triton.testing.do_bench(mla_aiter)
+    return out_aiter, None, t
```
Use s_q (and q.device) when building Qo metadata.
seq_lens_qo is hard-coded to 1 and metadata tensors rely on the global default device. If s_q changes or the default device isn’t set, q.view and the metadata can mismatch. Consider deriving from s_q and placing metadata on q.device.
♻️ Proposed fix

```diff
-    qo_indptr = torch.zeros(b + 1, dtype=torch.int)
-    kv_indptr = torch.zeros(b + 1, dtype=torch.int)
-    seq_lens_qo = torch.empty(b, dtype=torch.int)
-    seq_lens_qo.fill_(1)
-    max_seqlen_qo = seq_lens_qo.max().item()
+    device = q.device
+    qo_indptr = torch.zeros(b + 1, dtype=torch.int, device=device)
+    kv_indptr = torch.zeros(b + 1, dtype=torch.int, device=device)
+    seq_lens_qo = torch.full((b,), s_q, dtype=torch.int, device=device)
+    max_seqlen_qo = s_q
@@
-    kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0)
-    qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0)
+    kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0)
+    qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0)
@@
-        flat_indices.append(torch.arange(start, end, dtype=torch.int))
+        flat_indices.append(torch.arange(start, end, dtype=torch.int, device=device))
@@
-    kv_last_page_lens = torch.ones(b, dtype=torch.int)
+    kv_last_page_lens = torch.ones(b, dtype=torch.int, device=device)
```

🤖 Prompt for AI Agents
In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py` around lines 68-114: The Qo metadata is created on the global default device and uses a
hard-coded seq length of 1; update run_mla_aiter so seq_lens_qo is derived from
s_q (e.g., torch.full((b,), s_q, dtype=torch.int, device=q.device) or
torch.ones(...)*s_q) and create qo_indptr, kv_indptr, kv_last_page_lens and any
torch.arange calls (flat_indices) on q.device (and matching dtype) so they live
on the same device as q; ensure max_seqlen_qo is computed from the new
seq_lens_qo and build kv_indices from device-local tensors to avoid
q.view/device mismatches when calling mla_decode_fwd.
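A complementary, hypothetical sanity check (not part of the PR or the proposed fix above) that would surface such device mismatches right before the kernel launch:

```python
import torch


def assert_same_device(q: torch.Tensor, **tensors: torch.Tensor) -> None:
    # Fail fast if any metadata tensor ended up on a different device than q.
    for name, t in tensors.items():
        assert t.device == q.device, f"{name} is on {t.device}, expected {q.device}"


# Illustrative call site inside run_mla_aiter, just before mla_decode_fwd:
# assert_same_device(q, qo_indptr=qo_indptr, kv_indptr=kv_indptr,
#                    kv_indices=kv_indices, kv_last_page_lens=kv_last_page_lens)
```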
LeiWang1999 left a comment:
Overall LGTM, thanks for your contributions! But I left a simple comment.
LGTM, Merged :)
Modifications
Tests
```shell
python examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py --compare
python examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py --compare
python examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
python examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py --autotune
```

Summary by CodeRabbit
New Features
Improvements