[AMD] Fix ROCm FP8 dtype selection and MFMA support on gfx942/gfx950 #1743
base: main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Introduces runtime FP8 dtype selection utilities and replaces hardcoded FP8 dtypes across examples, tests, and kernels; refactors the HIP FP8 type wrappers; and extends the MFMA intrinsic generator to recognize additional FP8/BF8 variants. Changes are limited to dtype selection, wrappers, and generator mapping.
Sequence Diagram

```mermaid
sequenceDiagram
    participant App as Application / Example
    participant SelStr as select_fp8_e4m3_dtype()
    participant Platform as Platform Detector
    participant SelTorch as select_torch_fp8_e4m3_dtype()
    participant Torch as PyTorch
    App->>SelTorch: request torch.dtype for FP8
    SelTorch->>SelStr: request FP8 dtype string
    SelStr->>Platform: detect target (CUDA / ROCm / gfx)
    Platform-->>SelStr: platform info
    SelStr-->>SelTorch: return dtype string (e.g., "float8_e4m3fn" / "float8_e4m3fnuz")
    SelTorch->>Torch: map string -> torch.dtype
    Torch-->>SelTorch: return torch.dtype
    SelTorch-->>App: provide runtime torch.dtype for kernels/tests
```
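For readers skimming the diagram, here is a minimal sketch of how such a two-step selection could be composed. The helper names and the gfx950 architecture check come from this PR, but the exact return value of the gfx950 branch is an assumption for illustration, not the PR's actual implementation.

```python
# Illustrative sketch only -- not the PR's implementation.
# Assumption: gfx950 uses the OCP "float8_e4m3fn" format, while other
# ROCm targets (e.g. gfx942) fall back to "float8_e4m3fnuz".
import torch


def select_fp8_e4m3_dtype() -> str:
    if torch.version.hip is None:        # CUDA build
        return "float8_e4m3fn"
    if not torch.cuda.is_available():    # ROCm build, no visible GPU
        return "float8_e4m3fnuz"
    device = torch.cuda.current_device()
    gcn_arch = getattr(torch.cuda.get_device_properties(device), "gcnArchName", "")
    return "float8_e4m3fn" if gcn_arch.startswith("gfx950") else "float8_e4m3fnuz"


def select_torch_fp8_e4m3_dtype() -> torch.dtype:
    # Map the dtype string onto the corresponding torch dtype attribute.
    return getattr(torch, select_fp8_e4m3_dtype())
```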
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@tilelang/utils/__init__.py`:
- Line 3: Remove the unused "# noqa: F401" from the import line in
tilelang/utils/__init__.py so Ruff no longer flags the directive as unused;
locate the line that imports determine_target, select_fp8_e4m3_dtype, and
select_torch_fp8_e4m3_dtype and delete the trailing "# noqa: F401"
(alternatively, if you intended to silence F401, enable F401 in the Ruff config
instead).
In `@tilelang/utils/target.py`:
- Around line 73-79: The dtype-selection logic currently queries device 0 via
torch.cuda.get_device_properties(0); change it to use the active CUDA/HIP device
by calling torch.cuda.current_device() (or equivalent) and pass that index into
torch.cuda.get_device_properties so the gcn_arch check (gcnArchName) reflects
the currently selected GPU; update the block in tilelang/utils/target.py where
torch.version.hip, torch.cuda.is_available(), props =
torch.cuda.get_device_properties(0), and gcn_arch.startswith("gfx950") are used
to instead call torch.cuda.get_device_properties(current_device) (using
torch.cuda.current_device()) before inspecting gcnArchName.
🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)
67-79: Consider adding a float constructor for API symmetry.
`fp8_e5_t` lacks a constructor from `float` while `fp8_e4_t` provides one (lines 43-51). If this asymmetry is intentional for the current use cases, this is fine. Otherwise, consider adding it for API consistency:

♻️ Optional: Add float constructor to fp8_e5_t

```diff
 struct fp8_e5_t {
   unsigned char data;
   __device__ fp8_e5_t() {}
   __device__ fp8_e5_t(hip_fp8_e5_t val) {
     data = *reinterpret_cast<unsigned char *>(&val);
   }
+  __device__ fp8_e5_t(float val) {
+    data = __hip_cvt_float_to_fp8(val, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
+  }
   __device__ operator hip_fp8_e5_t() const {
     return *reinterpret_cast<const hip_fp8_e5_t *>(&data);
   }
   __device__ operator float() const {
     return static_cast<float>(static_cast<hip_fp8_e5_t>(*this));
   }
 };
```
tilelang/utils/__init__.py (Outdated)

```diff
 """The profiler and convert to torch utils"""

-from .target import determine_target  # noqa: F401
+from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype  # noqa: F401
```
Remove the unused # noqa: F401 to satisfy Ruff.
Ruff flags the directive as unused on this line, which can fail linting. Either drop it or enable F401 in the config.
🧹 Proposed fix

```diff
-from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype  # noqa: F401
+from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion
from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
```
🧰 Tools
🪛 Ruff (0.14.14)
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/utils/target.py

```python
    if torch.version.hip is None:
        return "float8_e4m3fn"
    if not torch.cuda.is_available():
        return "float8_e4m3fnuz"
    props = torch.cuda.get_device_properties(0)
    gcn_arch = getattr(props, "gcnArchName", "")
    if gcn_arch.startswith("gfx950"):
```
🧩 Analysis chain

🏁 Scripts executed (repository: tile-ai/tilelang):

```bash
cd /tmp && find . -type f -name "target.py" | head -20
find . -type f -name "target.py" -path "*/tilelang/utils/*"
find . -type f -name "*.py" | grep -i target | head -20
git ls-files | grep target.py
cat -n tilelang/utils/target.py
rg "select_fp8_e4m3_dtype" --type py -B 2 -A 2
rg "torch.cuda.get_device_properties" --type py -B 3 -A 3
```

🌐 Web queries:
- PyTorch torch.cuda.current_device() torch.cuda.get_device_properties() multi-GPU best practices
- PyTorch torch.cuda.current_device() best practices multi-GPU systems
💡 Result:

What torch.cuda.current_device() means (and why it can surprise you)

- PyTorch tracks a per-process "currently selected" CUDA device; torch.cuda.current_device() returns that device index. By default it's typically cuda:0 unless you change it. CUDA tensor allocation defaults to this device. [1][2]
- If you restrict visibility with CUDA_VISIBLE_DEVICES, then cuda:0 is "GPU 0 among the visible set", which may map to different physical GPUs in different processes, so seeing current_device() == 0 can be totally correct in multi-process setups. [1][5]

Best practices on multi-GPU systems

1) Prefer explicit devices over relying on the "current device"
   Use torch.device("cuda:N") (or a passed-in device) and move tensors/modules explicitly (.to(device)). This avoids accidental allocations on the wrong GPU when the current device changes. [1]

2) If you do rely on the current device, set it intentionally (per process)
   - Use torch.cuda.set_device(local_rank) early, before creating CUDA tensors (or wrap code in with torch.cuda.device(local_rank):). [1][4]
   - This matters for distributed/NCCL: some collectives use torch.cuda.current_device() to decide which GPU to use, and PyTorch explicitly puts the responsibility on you to ensure each rank has the right GPU selected. [4]

3) With torchrun/DDP: map one process ↔ one GPU
   - torchrun provides LOCAL_RANK; use it to select the GPU and to configure DDP device_ids/output_device. [3]
   - Minimal pattern:

```python
import os, torch
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = model.to(local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], output_device=local_rank
)
```

   [3][4]

4) Use CUDA_VISIBLE_DEVICES to control which GPUs PyTorch can see
   This is often the cleanest way to partition GPUs across jobs/process groups, and it changes the meaning of device indices inside the process (re-indexing from 0). [1][5]

5) When creating new tensors inside forward, don't depend on the current device
   If you need a new tensor "like an existing tensor", use x.new_* / new_tensor so it inherits device/dtype, rather than torch.* factory calls that can follow the current-device context. [1]

Sources: PyTorch CUDA semantics / best practices [1], torch.cuda API docs [2], torchrun docs [3], torch.distributed docs (NCCL note about current_device) [4], PyTorch forum guidance on current_device, set_device, and CUDA_VISIBLE_DEVICES behavior [5].
Use the current device when querying GPU architecture for dtype selection.
In multi-GPU ROCm/HIP systems, device 0 may not be the active device. When selecting the FP8 dtype based on GPU architecture (gfx950 vs. other), the function must query the current device instead of hardcoding device 0, otherwise it will return the incorrect dtype.
🔧 Proposed fix

```diff
-    props = torch.cuda.get_device_properties(0)
+    device = torch.cuda.current_device()
+    props = torch.cuda.get_device_properties(device)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
    if torch.version.hip is None:
        return "float8_e4m3fn"
    if not torch.cuda.is_available():
        return "float8_e4m3fnuz"
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    gcn_arch = getattr(props, "gcnArchName", "")
    if gcn_arch.startswith("gfx950"):
```
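To make the multi-GPU concern concrete, here is a small illustration (not part of the PR) of how querying device 0 can diverge from the active device; the two-GPU setup is an assumption for demonstration.

```python
# Illustration only (assumes a node with at least two visible GPUs): the
# architecture string reported for device 0 is not necessarily that of the
# currently selected device.
import torch

if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    with torch.cuda.device(1):
        cur = torch.cuda.current_device()               # 1 inside this context
        fixed = torch.cuda.get_device_properties(0)     # always queries device 0
        active = torch.cuda.get_device_properties(cur)  # the GPU actually in use
        # gcnArchName only exists on ROCm builds, hence the getattr fallback.
        print(getattr(fixed, "gcnArchName", "n/a"),
              getattr(active, "gcnArchName", "n/a"))
```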
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py`:
- Around line 237-240: The selected fp8 helpers (select_fp8_e4m3_dtype and
select_fp8_e5m2_dtype) return dtype names as strings but
tl_matmul/assert_tl_matmul_correctness expect TileLang dtype objects (e.g.,
T.float8_e4m3fn); update the calls so you convert the returned string to the
TileLang dtype object before passing to assert_tl_matmul_correctness (for
example by resolving the string via the TileLang type namespace, e.g., mapping
or using getattr(T, dtype_name) to get T.float8_e4m3fn), and use those resolved
dtype objects when invoking assert_tl_matmul_correctness and tl_matmul.
🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)
160-179: Add storage ctor/assignment for `fp8_e5_4_t` for API parity.
`fp8_e4_4_t` supports construction and assignment from its storage type (uint32_t); `fp8_e5_4_t` doesn't. If generated code attempts to use packed storage for E5, this API gap can cause compilation errors. Align `fp8_e5_4_t` with the `fp8_e4_4_t` interface by adding a storage constructor and assignment operator.

♻️ Suggested parity additions

```diff
 struct __align__(4) fp8_e5_4_t {
   union {
     fp8_e5_4_storage_t data;
     struct {
       fp8_e5_t x;
       fp8_e5_t y;
       fp8_e5_t z;
       fp8_e5_t w;
     };
   };
   __device__ fp8_e5_4_t() {}
+  __device__ fp8_e5_4_t(const fp8_e5_4_storage_t &val) : data(val) {}
   __device__ fp8_e5_4_t(const hip_fp8x4_e5_t &val) {
     data = *reinterpret_cast<const fp8_e5_4_storage_t *>(&val);
   }
   __device__ operator hip_fp8x4_e5_t() const {
     return *reinterpret_cast<const hip_fp8x4_e5_t *>(&data);
   }
+  __device__ fp8_e5_4_t &operator=(const fp8_e5_4_storage_t &val) {
+    data = val;
+    return *this;
+  }
 };
```
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py

```python
    e4m3_dtype = select_fp8_e4m3_dtype()
    assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32)
    e5m2_dtype = select_fp8_e5m2_dtype()
    assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)
```
🧩 Analysis chain

🏁 Scripts executed (repository: tile-ai/tilelang):

```bash
# Check what the selector functions actually return at runtime
cd /tmp && cat > verify_fp8.py << 'EOF'
import sys
try:
    import tilelang.language as T
    from tilelang.utils import select_fp8_e4m3_dtype, select_fp8_e5m2_dtype
    for fn in (select_fp8_e4m3_dtype, select_fp8_e5m2_dtype):
        d = fn()
        print(f"{fn.__name__}: value={d}, type={type(d).__name__}, repr={repr(d)}")
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    import traceback
    traceback.print_exc()
EOF
python verify_fp8.py

# Find the implementations of these selector functions
rg "def select_fp8_e4m3_dtype|def select_fp8_e5m2_dtype" --type py -A 5

# Check what tl_matmul does with in_dtype
rg "def tl_matmul|def assert_tl_matmul" --type py -A 15

# Check the example file itself
cat -n examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py | sed -n '230,245p'
```
Convert string dtypes from select_fp8_*() to TileLang dtype objects.
select_fp8_e4m3_dtype() and select_fp8_e5m2_dtype() return strings (e.g., "float8_e4m3fn"), but tl_matmul asserts in_dtype against TileLang dtype objects like T.float8_e4m3fn. This will cause an AssertionError at runtime because "float8_e4m3fn" != T.float8_e4m3fn.
Safe normalization

```diff
+def _tl_dtype(d):
+    return getattr(T, d) if isinstance(d, str) else d
+
 def main():
-    e4m3_dtype = select_fp8_e4m3_dtype()
+    e4m3_dtype = _tl_dtype(select_fp8_e4m3_dtype())
     assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32)
-    e5m2_dtype = select_fp8_e5m2_dtype()
+    e5m2_dtype = _tl_dtype(select_fp8_e5m2_dtype())
     assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)
```
assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)🤖 Prompt for AI Agents
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py` around lines 237 -
240, The selected fp8 helpers (select_fp8_e4m3_dtype and select_fp8_e5m2_dtype)
return dtype names as strings but tl_matmul/assert_tl_matmul_correctness expect
TileLang dtype objects (e.g., T.float8_e4m3fn); update the calls so you convert
the returned string to the TileLang dtype object before passing to
assert_tl_matmul_correctness (for example by resolving the string via the
TileLang type namespace, e.g., mapping or using getattr(T, dtype_name) to get
T.float8_e4m3fn), and use those resolved dtype objects when invoking
assert_tl_matmul_correctness and tl_matmul.
LeiWang1999 left a comment
Thanks for your contribution, I left a simple comment that it would be better to rename select_fp8_type to determine_fp8_type.
tilelang/utils/target.py

```python
    return arch == "arm64"


def select_fp8_e4m3_dtype() -> str:
```
It would be better to rename it to determine_fp8_type.
Description
This PR fixes ROCm FP8 handling across gfx942/gfx950 by selecting the correct
FP8 variants at runtime and making MFMA/codegen recognize the FP8 dtypes used
by ROCm. It also consolidates FP8 selection into shared helpers so examples and
tests stay consistent across devices.
Key changes
- select_fp8_e4m3_dtype() / select_fp8_e5m2_dtype() helpers and Torch variants (a usage sketch follows the test list below).

Tests
- pytest -q testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
- python /opt/tilelang/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
- python /opt/tilelang/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
- python /opt/tilelang/examples/gemm_fp8/regression_example_gemm_fp8.py

CC: @Gongen-Ali
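As a quick orientation for reviewers, a minimal usage sketch of the helpers listed under Key changes is shown below; the tensor shapes and the float16 staging before the FP8 cast are illustrative assumptions, not code from this PR.

```python
# Illustrative usage sketch (not from this PR): pick the FP8 dtype at runtime
# instead of hardcoding "float8_e4m3fn", so the same example runs on CUDA,
# gfx942 and gfx950.
import torch
from tilelang.utils import select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype

in_dtype = select_fp8_e4m3_dtype()            # e.g. "float8_e4m3fn" or "float8_e4m3fnuz"
torch_dtype = select_torch_fp8_e4m3_dtype()   # matching torch.dtype for tensor creation

# Hypothetical reference inputs for a 128x128x128 GEMM check.
a = torch.randn(128, 128, device="cuda", dtype=torch.float16).to(torch_dtype)
b = torch.randn(128, 128, device="cuda", dtype=torch.float16).to(torch_dtype)
print(in_dtype, torch_dtype, a.dtype, b.dtype)
```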
Summary by CodeRabbit
- New Features
- Refactor
- Tests