Required prerequisites
- I have read the documentation (https://tilelang.com).
- I have searched the Issue Tracker to confirm that this hasn't already been reported. (Comment there if it has.)
What version of TileLang are you using?
0.1.6.post1+cu128.gitd9a0f131
System information
Details
/opt/conda/lib/python3.10/runpy.py:126: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
warn(RuntimeWarning(msg))
Collecting environment information...
PyTorch version: 2.9.1+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 4.1.2
Libc version: glibc-2.39
Python version: 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-87-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Nvidia driver version: 570.124.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9355 32-Core Processor
CPU family: 26
Model: 2
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 101%
CPU max MHz: 3550.0000
CPU min MHz: 1500.0000
BogoMIPS: 7099.85
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d debug_swap
Virtualization: AMD-V
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 64 MiB (64 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-7,64-71
NUMA node1 CPU(s): 8-15,72-79
NUMA node2 CPU(s): 16-23,80-87
NUMA node3 CPU(s): 24-31,88-95
NUMA node4 CPU(s): 32-39,96-103
NUMA node5 CPU(s): 40-47,104-111
NUMA node6 CPU(s): 48-55,112-119
NUMA node7 CPU(s): 56-63,120-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected
Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.1
[pip3] triton==3.5.1
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.9.1 pypi_0 pypi
[conda] triton 3.5.1 pypi_0 pypi
Problem description
I have a benchmark script I wrote for fp8_lighting_indexer.py:
Details
bench_indexer_tilelang.py
#!/usr/bin/env python3
import argparse
import torch
import os
import sys
from typing import Optional

# Optional TVM runtime import to dump CUDA/PTX sources
import tilelang
from tilelang import tvm
from tvm import runtime as tvm_rt

# Prefer local examples path resolution if running from repo root
try:
    from examples.deepseek_v32.utils import per_custom_dims_cast_to_fp8 as _to_fp8

    def to_fp8(x):
        # Cast along last dim to FP8 E4M3 to match kernel expectations
        # Handle both (x, dims, use_ue8m0) and (x, dims) signatures and return the scaled tensor only.
        try:
            x_scaled, _ = _to_fp8(x, dims=(-1,), use_ue8m0=False)
            return x_scaled
        except TypeError:
            out = _to_fp8(x, dims=(-1,))
            return out[0] if isinstance(out, tuple) else out
except Exception:
    def to_fp8(x):
        # Fallback: use PyTorch FP8 E4M3 if TileLang utils are unavailable
        if not hasattr(torch, "float8_e4m3fn"):
            raise RuntimeError("torch.float8_e4m3fn not available; install a CUDA-enabled PyTorch.")
        return x.to(torch.float8_e4m3fn)
# Try to ensure TVM runtime is importable (vendored TVM + build lib dirs)
def _ensure_tvm_runtime() -> bool:
    global tvm_rt
    if tvm_rt is not None:
        return True
    # First, try importing directly
    try:
        from tvm import runtime as _rt  # type: ignore
        tvm_rt = _rt  # type: ignore
        return True
    except Exception:
        pass
    # Add vendored TVM python path
    try:
        here = os.path.abspath(os.path.dirname(__file__))
        root = here
        vendored = os.path.join(root, "3rdparty", "tvm", "python")
        if os.path.isdir(vendored) and vendored not in sys.path:
            sys.path.insert(0, vendored)
        # Add build library dirs for TVM
        libdirs = [os.path.join(root, "build", "tvm"), os.path.join(root, "build", "lib")]
        libdirs = [p for p in libdirs if os.path.isdir(p)]
        if libdirs:
            sep = ":" if os.name != "nt" else ";"
            add = sep.join(libdirs)
            os.environ["TVM_LIBRARY_PATH"] = add + (sep + os.environ["TVM_LIBRARY_PATH"] if "TVM_LIBRARY_PATH" in os.environ else "")
            os.environ["LD_LIBRARY_PATH"] = add + (sep + os.environ["LD_LIBRARY_PATH"] if "LD_LIBRARY_PATH" in os.environ else "")
        from tvm import runtime as _rt  # type: ignore
        tvm_rt = _rt  # type: ignore
        return True
    except Exception:
        return False
# Utilities to extract CUDA/PTX sources from compiled TileLang kernels
def _get_rt_mod_from_kernel(kernel) -> Optional["tvm_rt.Module"]:
    if tvm_rt is None:
        return None
    # Direct attachments
    for attr in ("rt_mod", "module", "mod"):
        m = getattr(kernel, attr, None)
        if m is not None and isinstance(m, tvm_rt.Module):
            return m
    # Nested wrappers commonly used by TileLang frontends
    for inner_name in ("impl", "fn", "kernel", "launcher"):
        inner = getattr(kernel, inner_name, None)
        if inner is None:
            continue
        for attr in ("rt_mod", "module", "mod"):
            m = getattr(inner, attr, None)
            if m is not None and isinstance(m, tvm_rt.Module):
                return m
    return None

# Prefer printing CUDA source from TileLang artifact if available
def _print_kernel_cuda_from_artifact(kernel, kernel_name: str) -> bool:
    try:
        art = getattr(kernel, "artifact", None)
        if art is not None:
            src = getattr(art, "kernel_source", None)
            if src:
                print(f"===== BEGIN {kernel_name} CUDA =====")
                print(src)
                print(f"===== END {kernel_name} CUDA =====")
                return True
    except Exception as e:
        print(f"[KERNEL_SRC] Artifact kernel_source not available for {kernel_name}: {e}")
    return False
def _print_kernel_sources(kernel, kernel_name: str):
    if tvm_rt is None:
        if not _ensure_tvm_runtime():
            print(f"[KERNEL_SRC] TVM runtime not available; cannot dump sources for {kernel_name}")
            return
    try:
        rt_mod = _get_rt_mod_from_kernel(kernel)
        if rt_mod is None:
            print(f"[KERNEL_SRC] No TVM runtime module found on kernel {kernel_name}")
            return
        # Many CUDA builds store device code in imported_modules[0].
        try:
            imported = list(rt_mod.imported_modules)
        except Exception:
            imported = []
        if not imported:
            imported = [rt_mod]
        device_mod = imported[0]
        any_src = False
        for fmt in ("cuda", "ptx"):
            try:
                src = device_mod.get_source(fmt)
            except Exception:
                src = None
            if src:
                any_src = True
                header = f"===== BEGIN {kernel_name} {fmt.upper()} ====="
                footer = f"===== END {kernel_name} {fmt.upper()} ====="
                print(header)
                print(src)
                print(footer)
        if not any_src:
            print(f"[KERNEL_SRC] Module did not expose CUDA/PTX sources for {kernel_name}")
    except Exception as e:
        print(f"[KERNEL_SRC] Failed to get sources for {kernel_name}: {e}")
# TileLang example kernels for lightning indexer
from examples.deepseek_v32.fp8_lighting_indexer import (
    mqa_attn_return_logits,
    mqa_attn_return_logits_interface,
)

def bench_tl_indexer_wrapper(seq_len: int,
                             seq_len_kv: int,
                             heads: int = 4,
                             index_dim: int = 64,
                             iters: int = 50,
                             warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")
    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Convert to FP8 E4M3 to match kernel signature
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)
    # Precompute kv_scales similar to reference: sqrt(mean(k^2)) along dim=-1
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    # Warmup
    for _ in range(warmup):
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
    torch.cuda.synchronize()
    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] WRAPPER S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")
    # Dump kernel source (prefer artifact.kernel_source)
    try:
        kwrap = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
        # Force a tiny compile-run to populate artifact
        S_sm, SKV_sm = 32, 32
        q_sm = torch.randn(S_sm, heads, index_dim, device=device, dtype=torch.float32)
        kv_sm = torch.randn(SKV_sm, index_dim, device=device, dtype=torch.float32)
        q_sm_fp8 = to_fp8(q_sm)
        kv_sm_fp8 = to_fp8(kv_sm)
        kv_scales_sm = kv_sm.pow(2).mean(dim=-1).sqrt()
        weights_sm = torch.randn(S_sm, heads, device=device, dtype=torch.float32)
        cu_seqlen_ks_sm = torch.zeros(S_sm, dtype=torch.int32, device=device)
        cu_seqlen_ke_sm = torch.full((S_sm,), SKV_sm, dtype=torch.int32, device=device)
        logits_sm = torch.empty(S_sm, SKV_sm, device=device, dtype=torch.float32)
        kwrap(q_sm_fp8.view(S_sm * heads, index_dim), kv_sm_fp8, kv_scales_sm,
              logits_sm, weights_sm, cu_seqlen_ks_sm, cu_seqlen_ke_sm)
        torch.cuda.synchronize()
        if not _print_kernel_cuda_from_artifact(kwrap, "mqa_attn_return_logits_kernel"):
            _print_kernel_sources(kwrap, "mqa_attn_return_logits_kernel")
    except Exception as e:
        print(f"[KERNEL_SRC] Wrapper: unable to dump kernel sources: {e}")
def bench_tl_indexer_impl(seq_len: int,
                          seq_len_kv: int,
                          heads: int = 4,
                          index_dim: int = 64,
                          iters: int = 50,
                          warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")
    # Compile kernel once
    kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Convert to FP8 E4M3 to match kernel signature
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    logits = torch.empty(seq_len, seq_len_kv, device=device, dtype=torch.float32)
    # Warmup
    for _ in range(warmup):
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
    torch.cuda.synchronize()
    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] IMPL S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")
    # Dump kernel source for impl (prefer artifact.kernel_source)
    try:
        # kernel is already compiled by this point, but if artifact is not populated, force a tiny run
        if not _print_kernel_cuda_from_artifact(kernel, "mqa_attn_return_logits_kernel"):
            S_sm, SKV_sm = 32, 32
            q_sm = torch.randn(S_sm, heads, index_dim, device=device, dtype=torch.float32)
            kv_sm = torch.randn(SKV_sm, index_dim, device=device, dtype=torch.float32)
            q_sm_fp8 = to_fp8(q_sm)
            kv_sm_fp8 = to_fp8(kv_sm)
            kv_scales_sm = kv_sm.pow(2).mean(dim=-1).sqrt()
            weights_sm = torch.randn(S_sm, heads, device=device, dtype=torch.float32)
            cu_seqlen_ks_sm = torch.zeros(S_sm, dtype=torch.int32, device=device)
            cu_seqlen_ke_sm = torch.full((S_sm,), SKV_sm, dtype=torch.int32, device=device)
            logits_sm = torch.empty(S_sm, SKV_sm, device=device, dtype=torch.float32)
            kernel(q_sm_fp8.view(S_sm * heads, index_dim), kv_sm_fp8, kv_scales_sm,
                   logits_sm, weights_sm, cu_seqlen_ks_sm, cu_seqlen_ke_sm)
            torch.cuda.synchronize()
            if not _print_kernel_cuda_from_artifact(kernel, "mqa_attn_return_logits_kernel"):
                _print_kernel_sources(kernel, "mqa_attn_return_logits_kernel")
    except Exception as e:
        print(f"[KERNEL_SRC] Impl: unable to dump kernel sources: {e}")
def parse_int_list(s: str):
    vals = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        vals.append(int(part))
    return vals

def main():
    parser = argparse.ArgumentParser(description="Benchmark TileLang lightning indexer (DeepSeek V3.2)")
    parser.add_argument("--seq-lens", type=parse_int_list, default="4096,16384,163840",
                        help="Comma-separated sequence lengths S (default: 4096,16384,163840)")
    parser.add_argument("--kv-lens", type=parse_int_list, default=None,
                        help="Comma-separated KV lengths SKV; if omitted, uses seq-lens")
    parser.add_argument("--heads", type=int, default=4, help="Indexer heads H (default: 4)")
    parser.add_argument("--dim", type=int, default=64, help="Indexer dimension D (default: 64)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: wrapper (interface), impl (kernel), or both (default)")
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")
    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_int_list(args.seq_lens)
    kv_lens = None
    if args.kv_lens is None:
        kv_lens = seq_lens
    else:
        kv_lens = args.kv_lens if isinstance(args.kv_lens, list) else parse_int_list(args.kv_lens)
        if len(kv_lens) != len(seq_lens):
            raise ValueError("--kv-lens must have the same number of elements as --seq-lens")
    for S, SKV in zip(seq_lens, kv_lens):
        if args.mode in ("both", "wrapper"):
            bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
        if args.mode in ("both", "impl"):
            bench_tl_indexer_impl(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)

if __name__ == "__main__":
    main()

If I run this with 4 heads, like this:
python bench_indexer_tilelang.py --heads 4 --seq-lens '4096,16384,150840'
I get results quickly:
2025-11-14 00:40:19 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /root/TileLang/build
CUDA device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
[TILELANG_INDEXER] WRAPPER S=4096 SKV=4096 H=4 D=64 avg_ms=0.062 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=4096 SKV=4096 H=4 D=64 avg_ms=0.050 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=16384 SKV=16384 H=4 D=64 avg_ms=0.776 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=16384 SKV=16384 H=4 D=64 avg_ms=0.701 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=150840 SKV=150840 H=4 D=64 avg_ms=65.442 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=150840 SKV=150840 H=4 D=64 avg_ms=59.975 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
If I run it with 8 heads and reduce the seq-lens a bit, I still get results, but they're slower:
root@a1cb74468f35:~/TileLang# python bench_indexer_tilelang.py --heads 8 --seq-lens '4096,16384,150000'
2025-11-14 00:44:36 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /root/TileLang/build
CUDA device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
[TILELANG_INDEXER] WRAPPER S=4096 SKV=4096 H=8 D=64 avg_ms=0.112 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=4096 SKV=4096 H=8 D=64 avg_ms=0.100 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=16384 SKV=16384 H=8 D=64 avg_ms=1.091 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=16384 SKV=16384 H=8 D=64 avg_ms=1.032 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=150000 SKV=150000 H=8 D=64 avg_ms=91.376 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=150000 SKV=150000 H=8 D=64 avg_ms=88.629 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
root@a1cb74468f35:~/TileLang#
If I increase heads to 16, it never finishes.
Isn't DeepSeek V3.2-Exp supposed to use 64 indexer heads? Why does this kernel only work with up to 8?
Reproducible example code
The full benchmark script (bench_indexer_tilelang.py) and the exact commands are included in the problem description above; a condensed sketch of the failing configuration follows.
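For convenience, here is a minimal sketch condensed from the script above. It assumes the same kernel call signature used there (compile via mqa_attn_return_logits, then pass q flattened to (S*H, D)), and uses a plain torch.float8_e4m3fn cast in place of the repo's per_custom_dims_cast_to_fp8 helper.

import torch
from examples.deepseek_v32.fp8_lighting_indexer import mqa_attn_return_logits

# H=16 is the smallest head count that never finishes for me; H=4 and H=8 are fine.
S, SKV, H, D = 4096, 4096, 16, 64
device = torch.device("cuda")

# Compile the indexer kernel for H heads (I have not isolated whether the hang
# happens here or on the first launch below).
kernel = mqa_attn_return_logits(heads=H, index_dim=D)

# Inputs matching the shapes used in the full script; plain FP8 cast as an approximation.
q_fp8 = torch.randn(S, H, D, device=device).to(torch.float8_e4m3fn)
kv = torch.randn(SKV, D, device=device)
kv_fp8 = kv.to(torch.float8_e4m3fn)
kv_scales = kv.pow(2).mean(dim=-1).sqrt()
weights = torch.randn(S, H, device=device)
cu_seqlen_ks = torch.zeros(S, dtype=torch.int32, device=device)
cu_seqlen_ke = torch.full((S,), SKV, dtype=torch.int32, device=device)
logits = torch.empty(S, SKV, device=device, dtype=torch.float32)

kernel(q_fp8.view(S * H, D), kv_fp8, kv_scales, logits, weights, cu_seqlen_ks, cu_seqlen_ke)
torch.cuda.synchronize()
print(logits[:2, :8])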
Traceback
Expected behavior
No response
Additional context
No response