13 changes: 8 additions & 5 deletions examples/gemm_fp8/example_tilelang_gemm_amd.py
@@ -2,6 +2,7 @@
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
from tilelang.utils import select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
import itertools


@@ -17,8 +18,9 @@ def supply_prog(args):
a_param, b_param = args
M, K = a_param.shape
N, _ = b_param.shape
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
fp8_dtype = select_torch_fp8_e4m3_dtype()
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
return [a, b]


@@ -53,7 +55,7 @@ def get_configs()
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = T.float8_e4m3fnuz
dtype = select_fp8_e4m3_dtype()
accum_dtype = T.float32

@T.prim_func
@@ -104,8 +106,9 @@ def gemm_fp8_ss(

def test_gemm_fp8(M, N, K):
kernel = fp8_matmul(M, N, K)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
fp8_dtype = select_torch_fp8_e4m3_dtype()
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
7 changes: 5 additions & 2 deletions examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
@@ -7,6 +7,7 @@
from tilelang.tileop.base import GemmWarpPolicy
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
from tilelang.utils import select_fp8_e4m3_dtype

tilelang.testing.set_random_seed(0)

@@ -45,12 +46,14 @@ def tl_matmul(
num_stages,
k_pack=2,
num_threads=256,
in_dtype=T.float8_e4m3fnuz,
in_dtype=None,
out_dtype=T.float32,
accum_dtype=T.float32,
a_transposed=False,
b_transposed=True,
):
if in_dtype is None:
in_dtype = select_fp8_e4m3_dtype()
b_preshuffle = True
warp_size = 64
num_warps = num_threads // warp_size
@@ -164,7 +167,7 @@ def shuffle_weight(


def assert_tl_matmul_correctness(M, N, K, k_pack=1, a_transposed=False, b_transposed=True):
in_dtype = T.float8_e4m3fnuz
in_dtype = select_fp8_e4m3_dtype()
out_dtype = T.float32
accum_dtype = T.float32
kernel = tl_matmul(
105 changes: 69 additions & 36 deletions src/tl_templates/hip/hip_fp8.h
@@ -1,5 +1,6 @@
#pragma once
#include <hip/amd_detail/amd_hip_fp8.h>
#include <stdint.h>

#define HIP_FP8_ENABLED 1

@@ -16,53 +17,84 @@

#if (TILELANG_FP8_E4M3_VARIANT == TILELANG_FP8_E4M3_VARIANT_FN)
#if defined(__clang__) && defined(__HIPCC__)
#if __is_identifier(__hip_fp8_e4m3)
#if !__is_identifier(__hip_fp8_e4m3)
#define TILELANG_HAVE_FP8_E4M3_FN 1
#endif
#endif
#endif

#if defined(TILELANG_HAVE_FP8_E4M3_FN)
using fp8_e4_t = __hip_fp8_e4m3;
using fp8_e4_2_t = __hip_fp8x2_e4m3;
using fp8_e4_4_storage_t = __hip_fp8x4_e4m3;
using hip_fp8_e4_t = __hip_fp8_e4m3;
using hip_fp8x2_e4_t = __hip_fp8x2_e4m3;
using hip_fp8x4_e4_t = __hip_fp8x4_e4m3;
#else
// FNUZ path (MI300X and universal fallback)
using fp8_e4_t = __hip_fp8_e4m3_fnuz;
using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;
using fp8_e4_4_storage_t = __hip_fp8x4_e4m3_fnuz;
using hip_fp8_e4_t = __hip_fp8_e4m3_fnuz;
using hip_fp8x2_e4_t = __hip_fp8x2_e4m3_fnuz;
using hip_fp8x4_e4_t = __hip_fp8x4_e4m3_fnuz;
#endif

struct fp8_e4_t {
unsigned char data;
__device__ fp8_e4_t() {}
__device__ fp8_e4_t(hip_fp8_e4_t val) {
data = *reinterpret_cast<unsigned char *>(&val);
}
__device__ fp8_e4_t(float val) {
constexpr __hip_fp8_interpretation_t interp =
#if (TILELANG_FP8_E4M3_VARIANT == TILELANG_FP8_E4M3_VARIANT_FNUZ)
__HIP_E4M3_FNUZ;
#else
__HIP_E4M3;
#endif
data = __hip_cvt_float_to_fp8(val, __HIP_SATFINITE, interp);
}
__device__ operator hip_fp8_e4_t() const {
return *reinterpret_cast<const hip_fp8_e4_t *>(&data);
}
__device__ operator float() const {
return static_cast<float>(static_cast<hip_fp8_e4_t>(*this));
}
};

using fp8_e4_2_t = hip_fp8x2_e4_t;
using fp8_e4_4_storage_t = uint32_t;

// Additional FP8 types for compatibility
using fp8_e5_t = __hip_fp8_e5m2_fnuz;
using hip_fp8_e5_t = __hip_fp8_e5m2_fnuz;
using fp8_e5_2_t = __hip_fp8x2_e5m2_fnuz;

struct fp8_e5_t {
unsigned char data;
__device__ fp8_e5_t() {}
__device__ fp8_e5_t(hip_fp8_e5_t val) {
data = *reinterpret_cast<unsigned char *>(&val);
}
__device__ operator hip_fp8_e5_t() const {
return *reinterpret_cast<const hip_fp8_e5_t *>(&data);
}
__device__ operator float() const {
return static_cast<float>(static_cast<hip_fp8_e5_t>(*this));
}
};
// Note: E8M0 types are not supported in current HIP version
// using fp8_e8_t = __hip_fp8_e8m0_fnuz;
// using fp8_e8_2_t = __hip_fp8x2_e8m0_fnuz;

// Simple wrapper that provides member access for generated code
struct fp8_e4_4_t {
union {
// __hip_fp8x4_e4m3_fnuz data;
fp8_e4_4_storage_t data;
struct {
fp8_e4_t x, y, z, w;
};
};

// Default constructor
__device__ fp8_e4_4_t() = default;
struct __align__(4) fp8_e4_4_t {
fp8_e4_4_storage_t data;

// Constructor from __hip_fp8x4_e4m3_fnuz
__device__ fp8_e4_4_t() {}
__device__ fp8_e4_4_t(const fp8_e4_4_storage_t &val) : data(val) {}
__device__ fp8_e4_4_t(const hip_fp8x4_e4_t &val) {
data = *reinterpret_cast<const fp8_e4_4_storage_t *>(&val);
}

// Constructor from float4
__device__ fp8_e4_4_t(const float4 &val) : data(val) {}

// Conversion operator to __hip_fp8x4_e4m3_fnuz
__device__ operator fp8_e4_4_storage_t() const { return data; }
__device__ operator hip_fp8x4_e4_t() const {
return *reinterpret_cast<const hip_fp8x4_e4_t *>(&data);
}

// Assignment operator
__device__ fp8_e4_4_t &operator=(const fp8_e4_4_storage_t &val) {
data = val;
return *this;
@@ -80,16 +112,17 @@ struct __align__(16) fp8_e4_16_t {
};

// FP8 E5M2 vector types
struct fp8_e5_4_t {
union {
__hip_fp8x4_e5m2_fnuz data;
struct {
fp8_e5_t x, y, z, w;
};
};
__device__ fp8_e5_4_t() = delete;
__device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) : data(val) {}
__device__ operator __hip_fp8x4_e5m2_fnuz() const { return data; }
using fp8_e5_4_storage_t = uint32_t;

struct __align__(4) fp8_e5_4_t {
fp8_e5_4_storage_t data;
__device__ fp8_e5_4_t() {}
__device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) {
data = *reinterpret_cast<const fp8_e5_4_storage_t *>(&val);
}
__device__ operator __hip_fp8x4_e5m2_fnuz() const {
return *reinterpret_cast<const __hip_fp8x4_e5m2_fnuz *>(&data);
}
};

struct __align__(8) fp8_e5_8_t {
4 changes: 3 additions & 1 deletion tilelang/intrinsics/mfma_macro_generator.py
@@ -49,6 +49,7 @@ class MatrixCoreIntrinEmitter:
"int32": "int32",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
"float8_e4m3fn": "e4m3fn",
"float8_e4m3fnuz": "e4m3fnuz",
"float8_e5m2fnuz": "e5m2fnuz",
}
@@ -108,7 +109,7 @@ def __init__(

def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str):
if a_dtype in ["float8_e4m3fnuz", "float8_e5m2fnuz", T.int8]:
if a_dtype in ["float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2fnuz", T.int8]:
self.k_dim = 32
return
a_dtype = DataType(a_dtype)
@@ -141,6 +142,7 @@ def _initialize_mfma_prefix(self, k_dim=16):
"float32": "f32",
"int8": "i8",
"int32": "i32",
"float8_e4m3fn": "fp8",
"float8_e4m3fnuz": "fp8",
"float8_e5m2fnuz": "fp8",
}[in_dtype]
2 changes: 1 addition & 1 deletion tilelang/utils/__init__.py
@@ -1,6 +1,6 @@
"""The profiler and convert to torch utils"""

from .target import determine_target # noqa: F401
from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype # noqa: F401
⚠️ Potential issue | 🟡 Minor

Remove the unused # noqa: F401 to satisfy Ruff.

Ruff flags the directive as unused on this line, which can fail linting. Either drop it or enable F401 in the config.

🧹 Proposed fix
-from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype  # noqa: F401
+from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
🧰 Tools
🪛 Ruff (0.14.14)

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


from .tensor import TensorSupplyType, torch_assert_close, map_torch_type # noqa: F401
from .language import (
is_global, # noqa: F401
25 changes: 25 additions & 0 deletions tilelang/utils/target.py
@@ -64,6 +64,31 @@ def check_metal_availability() -> bool:
return arch == "arm64"


def select_fp8_e4m3_dtype() -> str:
It would be better to rename this to determine_fp8_type.

"""
Select the correct FP8 E4M3 dtype string for the current platform.
- CUDA defaults to FP8 E4M3FN.
- ROCm uses FNUZ except gfx950 (OCP), which requires FN.
"""
if torch.version.hip is None:
return "float8_e4m3fn"
if not torch.cuda.is_available():
return "float8_e4m3fnuz"
props = torch.cuda.get_device_properties(0)
gcn_arch = getattr(props, "gcnArchName", "")
if gcn_arch.startswith("gfx950"):
Comment on lines +73 to +79
⚠️ Potential issue | 🟠 Major

Use the current device when querying GPU architecture for dtype selection.

In multi-GPU ROCm/HIP systems, device 0 may not be the active device. When selecting the FP8 dtype based on GPU architecture (gfx950 vs. other), the function must query the current device instead of hardcoding device 0; otherwise it may return the wrong dtype.

🔧 Proposed fix
-    props = torch.cuda.get_device_properties(0)
+    device = torch.cuda.current_device()
+    props = torch.cuda.get_device_properties(device)
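For reviewers who want to sanity-check the selection on their own node, a minimal sketch that uses only the helpers added in this PR (it queries the current device, as suggested above; printed values depend on the local GPU):

import torch
from tilelang.utils import select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    print(getattr(props, "gcnArchName", ""))  # arch string the helper inspects (gfx950 -> FN variant)
print(select_fp8_e4m3_dtype())        # "float8_e4m3fn" or "float8_e4m3fnuz"
print(select_torch_fp8_e4m3_dtype())  # matching torch dtype, e.g. torch.float8_e4m3fnuz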

return "float8_e4m3fn"
return "float8_e4m3fnuz"


def select_torch_fp8_e4m3_dtype() -> torch.dtype:
dtype_name = select_fp8_e4m3_dtype()
torch_dtype = getattr(torch, dtype_name, None)
if torch_dtype is None:
raise RuntimeError(f"PyTorch does not expose dtype {dtype_name}")
return torch_dtype


def normalize_cutedsl_target(target: str | Target) -> Target | None:
if isinstance(target, Target):
if target.kind.name == "cuda" and "cutedsl" in target.keys: