13 changes: 8 additions & 5 deletions examples/gemm_fp8/example_tilelang_gemm_amd.py
@@ -2,6 +2,7 @@
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
from tilelang.utils import select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
import itertools


@@ -17,8 +18,9 @@ def supply_prog(args):
a_param, b_param = args
M, K = a_param.shape
N, _ = b_param.shape
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
fp8_dtype = select_torch_fp8_e4m3_dtype()
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
return [a, b]


@@ -53,7 +55,7 @@ def get_configs():
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = T.float8_e4m3fnuz
dtype = select_fp8_e4m3_dtype()
accum_dtype = T.float32

@T.prim_func
@@ -104,8 +106,9 @@ def gemm_fp8_ss(

def test_gemm_fp8(M, N, K):
kernel = fp8_matmul(M, N, K)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
fp8_dtype = select_torch_fp8_e4m3_dtype()
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
7 changes: 5 additions & 2 deletions examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
@@ -7,6 +7,7 @@
from tilelang.tileop.base import GemmWarpPolicy
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
from tilelang.utils import select_fp8_e4m3_dtype

tilelang.testing.set_random_seed(0)

@@ -45,12 +46,14 @@ def tl_matmul(
num_stages,
k_pack=2,
num_threads=256,
in_dtype=T.float8_e4m3fnuz,
in_dtype=None,
out_dtype=T.float32,
accum_dtype=T.float32,
a_transposed=False,
b_transposed=True,
):
if in_dtype is None:
in_dtype = select_fp8_e4m3_dtype()
b_preshuffle = True
warp_size = 64
num_warps = num_threads // warp_size
@@ -164,7 +167,7 @@ def shuffle_weight(


def assert_tl_matmul_correctness(M, N, K, k_pack=1, a_transposed=False, b_transposed=True):
in_dtype = T.float8_e4m3fnuz
in_dtype = select_fp8_e4m3_dtype()
out_dtype = T.float32
accum_dtype = T.float32
kernel = tl_matmul(
22 changes: 13 additions & 9 deletions examples/gemm_fp8/example_tilelang_gemm_fp8.py
@@ -1,6 +1,7 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.utils import select_fp8_e4m3_dtype, select_fp8_e5m2_dtype


def calc_diff(x, y):
@@ -55,21 +56,24 @@ def test_gemm_fp8(M, N, K, dtype):


def main():
test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2)
test_gemm_fp8(1024, 1024, 1024, select_fp8_e4m3_dtype())
test_gemm_fp8(1024, 1024, 1024, select_fp8_e5m2_dtype())


def run_regression_perf():
M, N, K = 4096, 4096, 4096
dtype = "float8_e4m3"
dtype = select_fp8_e4m3_dtype()
kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
dtype = "float8_e5m2"
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
if torch.version.hip is None:
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
dtype = select_fp8_e5m2_dtype()
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
latency_e4m3 = profiler_e4m3.do_bench()
return latency_e4m3


if __name__ == "__main__":
24 changes: 15 additions & 9 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
@@ -1,6 +1,7 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.utils import select_fp8_e4m3_dtype, select_fp8_e5m2_dtype


@tilelang.jit(out_idx=[-1])
@@ -73,21 +74,26 @@ def test_gemm_fp8(M, N, K, dtype):


def main():
test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2)
test_gemm_fp8(1024, 1024, 8192, select_fp8_e4m3_dtype())
test_gemm_fp8(1024, 1024, 8192, select_fp8_e5m2_dtype())


def run_regression_perf():
M, N, K = 1024, 1024, 8192
dtype = "float8_e4m3"
dtype = select_fp8_e4m3_dtype()
kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
dtype = "float8_e5m2"
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
if torch.version.hip is None:
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
else:
latency_e4m3 = profiler_e4m3.do_bench()
if torch.version.hip is None:
dtype = select_fp8_e5m2_dtype()
kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype)
profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e5m2 = profiler_e5m2.do_bench(backend="cupti")
return (latency_e4m3 + latency_e5m2) / 2
return latency_e4m3


if __name__ == "__main__":
105 changes: 63 additions & 42 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
@@ -4,10 +4,10 @@
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter
from tilelang.utils.tensor import map_torch_type
from tilelang.utils import select_fp8_e4m3_dtype, select_fp8_e5m2_dtype

tilelang.testing.set_random_seed(0)

@@ -39,26 +39,17 @@ def tl_matmul(
assert in_dtype in [
T.float16,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.int8,
], "Currently only float16 and int8 are supported"
], "Currently only float16, float8, and int8 are supported"
assert out_dtype in [
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"

micro_size_x = micro_size_y = micro_size_k = 16

is_float8 = in_dtype in [
T.float8_e4m3fn,
T.float8_e5m2,
T.float8_e4m3fn,
T.float8_e5m2fnuz,
]
if out_dtype == T.int32 or is_float8:
micro_size_k = 32

# This is a debug config
block_row_warps = 2
block_col_warps = 2
@@ -78,34 +69,51 @@ def tl_matmul(
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
is_hip = torch.version.hip is not None
# MMA Wrapper to Auto Generate Code for MMA/MFMA
if is_hip:
mma_emitter = MatrixCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
else:
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)

micro_size_x = mma_emitter.M_DIM
micro_size_y = getattr(mma_emitter, "n_dim", getattr(mma_emitter, "N_DIM", micro_size_x))
micro_size_k = mma_emitter.k_dim
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)

warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
threads = mma_emitter.threads
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
local_size_c = mma_emitter.local_size_out
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols

@T.prim_func
def gemm_fp8_intrinsic(
@@ -158,7 +166,10 @@ def gemm_fp8_intrinsic(
)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
if is_hip:
mma_emitter.mfma(A_local, B_local, C_local, ki)
else:
mma_emitter.mma(A_local, B_local, C_local)

# Perform STMatrix
mma_emitter.stmatrix(
@@ -192,7 +203,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
elif in_dtype in {
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
else:
@@ -218,18 +234,23 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):


def main():
assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
e4m3_dtype = select_fp8_e4m3_dtype()
assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32)
e5m2_dtype = select_fp8_e5m2_dtype()
assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)
Comment on lines +237 to +240

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /tmp && cat > verify_fp8.py << 'EOF'
import sys
try:
    import tilelang.language as T
    from tilelang.utils import select_fp8_e4m3_dtype, select_fp8_e5m2_dtype
    for fn in (select_fp8_e4m3_dtype, select_fp8_e5m2_dtype):
        d = fn()
        print(f"{fn.__name__}: value={d}, type={type(d).__name__}, repr={repr(d)}")
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    import traceback
    traceback.print_exc()
EOF
python verify_fp8.py

Repository: tile-ai/tilelang

Length of output: 146


🏁 Script executed:

# Find the implementations of these selector functions
rg "def select_fp8_e4m3_dtype|def select_fp8_e5m2_dtype" --type py -A 5

Repository: tile-ai/tilelang

Length of output: 818


🏁 Script executed:

# Check what tl_matmul does with in_dtype
rg "def tl_matmul|def assert_tl_matmul" --type py -A 15

Repository: tile-ai/tilelang

Length of output: 37110


🏁 Script executed:

# Check the example file itself
cat -n examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py | sed -n '230,245p'

Repository: tile-ai/tilelang

Length of output: 713


Convert string dtypes from select_fp8_*() to TileLang dtype objects.

select_fp8_e4m3_dtype() and select_fp8_e5m2_dtype() return strings (e.g., "float8_e4m3fn"), but tl_matmul asserts in_dtype against TileLang dtype objects like T.float8_e4m3fn. This will cause an AssertionError at runtime because "float8_e4m3fn" != T.float8_e4m3fn.

Safe normalization
+def _tl_dtype(d):
+    return getattr(T, d) if isinstance(d, str) else d
+
 def main():
-    e4m3_dtype = select_fp8_e4m3_dtype()
+    e4m3_dtype = _tl_dtype(select_fp8_e4m3_dtype())
     assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32)
-    e5m2_dtype = select_fp8_e5m2_dtype()
+    e5m2_dtype = _tl_dtype(select_fp8_e5m2_dtype())
     assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)
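
For reference, a minimal standalone sketch of how the proposed _tl_dtype normalization would be used together with the new selector; it assumes, as the analysis above describes, that select_fp8_e4m3_dtype() returns a dtype-name string such as "float8_e4m3fn" (or "float8_e4m3fnuz" on AMD CDNA3):

import tilelang.language as T
from tilelang.utils import select_fp8_e4m3_dtype

def _tl_dtype(d):
    # Resolve a dtype-name string (e.g. "float8_e4m3fn") to the matching
    # attribute on tilelang.language; dtype objects pass through unchanged.
    return getattr(T, d) if isinstance(d, str) else d

in_dtype = _tl_dtype(select_fp8_e4m3_dtype())
# Assumption: the resolved object is one of the fp8 e4m3 dtypes accepted by
# tl_matmul's in_dtype assertion, so passing it on no longer trips the check.
assert in_dtype in (T.float8_e4m3fn, T.float8_e4m3fnuz)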
🤖 Prompt for AI Agents
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py` around lines 237-240: the selected fp8 helpers (select_fp8_e4m3_dtype and select_fp8_e5m2_dtype) return dtype names as strings, but tl_matmul/assert_tl_matmul_correctness expect TileLang dtype objects (e.g., T.float8_e4m3fn). Update the calls so the returned string is converted to the TileLang dtype object before it is passed on (for example by resolving the string via the TileLang type namespace with getattr(T, dtype_name) to get T.float8_e4m3fn), and use those resolved dtype objects when invoking assert_tl_matmul_correctness and tl_matmul.



def run_regression_perf():
M, N, K = 4096, 4096, 4096
out_dtype, accum_dtype = "float32", "float32"
in_dtype = T.float8_e4m3fn
in_dtype = select_fp8_e4m3_dtype()
kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
print(kernel_e4m3.get_kernel_source())
profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer)
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
if torch.version.hip is None:
latency_e4m3 = profiler_e4m3.do_bench(backend="cupti")
else:
latency_e4m3 = profiler_e4m3.do_bench()
return latency_e4m3


2 changes: 2 additions & 0 deletions src/target/codegen_hip.cc
@@ -942,6 +942,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float8_e4m3fnx8", "long"},
{"float8_e5m2fnuzx4", "fp8_e5_4_t"},
{"float8_e5m2fnuzx8", "long"},
{"float8_e5m2x4", "fp8_e5_4_t"},
{"float8_e5m2x8", "long"},
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),