Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker to confirm that this hasn't already been reported. (Comment there if it has.)
What version of TileLang are you using?
0.1.7.post1
System information
PyTorch version: 2.9.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.35
Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.6.105+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.82
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 550.54.15
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.2.1
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 79
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 0
BogoMIPS: 4399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 256 KiB (1 instance)
L3 cache: 55 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Vulnerable
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Vulnerable
Vulnerability Srbds: Not affected
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Vulnerable
Versions of relevant libraries:
[pip3] intel-cmplr-lib-ur==2025.3.1
[pip3] intel-openmp==2025.3.1
[pip3] mkl==2025.3.0
[pip3] numpy==2.0.2
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] nvtx==0.2.14
[pip3] onemkl-license==2025.3.0
[pip3] optree==0.18.0
[pip3] tbb==2022.3.0
[pip3] tcmlib==1.4.1
[pip3] torch==2.9.0+cu126
[pip3] torch_c_dlpack_ext==0.1.4
[pip3] torchao==0.10.0
[pip3] torchaudio==2.9.0+cu126
[pip3] torchdata==0.11.0
[pip3] torchsummary==1.5.1
[pip3] torchtune==0.6.1
[pip3] torchvision==0.24.0+cu126
[pip3] triton==3.5.0
[pip3] umf==1.0.2
[conda] Could not collect
Problem description
When using `T.gemm()` on an H200 (SM90), everything runs normally, but on a Colab Tesla T4 (SM75) with `target="auto"`, the compiled kernel fails at runtime with a CUDA device-side assertion. The assertion indicates that SM80 MMA instructions were generated for the SM75 device.
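For reference, the compute capability PyTorch sees can be confirmed directly (a Tesla T4 reports `(7, 5)`, i.e. SM75, while an H200 reports `(9, 0)`):

```python
import torch

# Tesla T4 -> (7, 5) (SM75); H200 -> (9, 0) (SM90).
major, minor = torch.cuda.get_device_capability(0)
print(f"Detected sm_{major}{minor}")
```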
Reproducible example code
The Python snippet:

```python
import torch
import tilelang
import tilelang.language as T

tilelang.set_log_level("WARNING")


def create_gemm(N: int, K: int, dtype: str):
    M = T.symbolic("M")

    @T.prim_func
    def kernel(
        A: T.Tensor[(M, K), dtype],
        B: T.Tensor[(N, K), dtype],
        C: T.Tensor[(M, N), dtype],
    ):
        with T.Kernel(T.ceildiv(N, 64), T.ceildiv(M, 64), threads=128) as (bx, by):
            A_sh = T.alloc_shared((64, 32), dtype)
            B_sh = T.alloc_shared((64, 32), dtype)
            C_sh = T.alloc_shared((64, 64), dtype)
            C_lo = T.alloc_fragment((64, 64), "float32")
            T.use_swizzle(panel_size=4)
            T.clear(C_lo)
            for k in T.Pipelined(T.ceildiv(K, 32), num_stages=2):
                T.copy(A[by * 64, k * 32], A_sh)
                T.copy(B[bx * 64, k * 32], B_sh)
                T.gemm(A_sh, B_sh, C_lo, transpose_B=True)
            T.copy(C_lo, C_sh)
            T.copy(C_sh, C[by * 64, bx * 64])
    return kernel


# Compile with auto target detection
kernel_func = create_gemm(256, 256, "float16")
compiled = tilelang.compile(kernel_func, target="auto")

# Test
device = "cuda"
a = torch.randn(128, 256, device=device, dtype=torch.float16)
b = torch.randn(256, 256, device=device, dtype=torch.float16)
c = torch.empty(128, 256, device=device, dtype=torch.float16)
compiled(a, b, c)
torch.cuda.synchronize()  # <-- FAILS HERE
print("Success!")
```

Traceback

```
/usr/local/lib/python3.12/dist-packages/tilelang/3rdparty/cutlass/include/cute/arch/mma_sm80.hpp:183: static void cute::SM80_16x8x16_F32F16F16F32_TN::fma(float &, float &, float &, float &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const float &, const float &, const float &, const float &): block: [3,0,0], thread: [127,0,0] Assertion `0 && "Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"` failed.
Traceback (most recent call last):
File "/content/test/l.py", line 47, in <module>
torch.cuda.synchronize() # <-- FAILS HERE
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 1083, in synchronize
return torch._C._cuda_synchronize()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
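As the message suggests, the assert can be surfaced synchronously by setting `CUDA_LAUNCH_BLOCKING=1` before CUDA is initialized, so the failure is reported at `compiled(a, b, c)` rather than at the later `synchronize()` (a debugging convenience, not a fix):

```python
import os

# Must be set before the first CUDA call so launches run synchronously and
# the device-side assert is attributed to the failing kernel launch.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # import after setting the variable
```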
Expected behavior
TileLang should do one of the following:
- generate code that is compatible with SM75 (using the `mma.sync` instructions available on Turing), or
- raise a clear error at compile time indicating that `T.gemm()` requires SM80+.

(A user-side version of the compile-time guard is sketched below.)
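As a stopgap, the repro itself can enforce the second option; below is a minimal user-side sketch reusing `kernel_func` from the snippet above, where the `(8, 0)` threshold is an assumption inferred from the `CUTE_ARCH_MMA_SM80_ENABLED` assertion:

```python
import torch
import tilelang

# Workaround sketch (not TileLang's actual behavior): refuse to compile the
# MMA-based kernel on pre-SM80 devices. The (8, 0) threshold is inferred from
# the SM80 assertion in the traceback above.
major, minor = torch.cuda.get_device_capability(0)
if (major, minor) < (8, 0):
    raise RuntimeError(
        f"This T.gemm() kernel appears to require SM80+; detected sm_{major}{minor}"
    )
compiled = tilelang.compile(kernel_func, target="auto")
```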
Additional context
From the error, it appears the CUTLASS/CuTe backend selects SM80 MMA instructions regardless of the target GPU architecture. With `target="auto"`, TileLang should detect SM75 and do one of the following:
- use SM75-compatible WMMA/MMA instructions,
- fall back to a non-Tensor-Core GEMM, or
- reject the compilation with a clear error message.