[AMD] Fix ROCm FP8 dtype selection and MFMA support on gfx942/gfx950 #1743
base: main
The diff adds a new helper to `tilelang/utils/target.py`, directly after `check_metal_availability()` (hunk truncated in this capture at the `gfx950` branch):

```diff
@@ -64,6 +64,31 @@ def check_metal_availability() -> bool:
     return arch == "arm64"
 
 
+def select_fp8_e4m3_dtype() -> str:
+    """
+    Select the correct FP8 E4M3 dtype string for the current platform.
+
+    - CUDA defaults to FP8 E4M3FN.
+    - ROCm uses FNUZ except gfx950 (OCP), which requires FN.
+    """
+    if torch.version.hip is None:
+        return "float8_e4m3fn"
+    if not torch.cuda.is_available():
+        return "float8_e4m3fnuz"
+    props = torch.cuda.get_device_properties(0)
+    gcn_arch = getattr(props, "gcnArchName", "")
+    if gcn_arch.startswith("gfx950"):
```

**Member** commented on `def select_fp8_e4m3_dtype() -> str:`:

> would be better to rename it into …
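For orientation, a minimal usage sketch of the new helper; resolving the returned name with `getattr(torch, ...)` is an assumption about how callers consume the string, not something shown in the diff:

```python
import torch
from tilelang.utils.target import select_fp8_e4m3_dtype

# Resolve the platform-appropriate FP8 E4M3 name to a torch dtype:
# torch.float8_e4m3fn on CUDA/gfx950, torch.float8_e4m3fnuz on other ROCm GPUs.
fp8 = getattr(torch, select_fp8_e4m3_dtype())

# Cast a tensor to the selected FP8 dtype.
x = torch.randn(16, 16, device="cuda").to(fp8)
```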
|
Comment on lines
+73
to
+79
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: cd /tmp && find . -type f -name "target.py" | head -20Repository: tile-ai/tilelang Length of output: 42 🏁 Script executed: find . -type f -name "target.py" -path "*/tilelang/utils/*"Repository: tile-ai/tilelang Length of output: 86 🏁 Script executed: find . -type f -name "*.py" | grep -i target | head -20Repository: tile-ai/tilelang Length of output: 86 🏁 Script executed: git ls-files | grep target.pyRepository: tile-ai/tilelang Length of output: 84 🏁 Script executed: cat -n tilelang/utils/target.pyRepository: tile-ai/tilelang Length of output: 9539 🏁 Script executed: rg "select_fp8_e4m3_dtype" --type py -B 2 -A 2Repository: tile-ai/tilelang Length of output: 3304 🏁 Script executed: rg "torch.cuda.get_device_properties" --type py -B 3 -A 3Repository: tile-ai/tilelang Length of output: 6868 🏁 Script executed: web_search
PyTorch torch.cuda.current_device() torch.cuda.get_device_properties() multi-GPU best practicesRepository: tile-ai/tilelang Length of output: 291 🌐 Web query:
💡 Result: What
|
||||||||||||||||||||||||||||||||
```diff
     if torch.version.hip is None:
         return "float8_e4m3fn"
     if not torch.cuda.is_available():
         return "float8_e4m3fnuz"
-    props = torch.cuda.get_device_properties(0)
+    device = torch.cuda.current_device()
+    props = torch.cuda.get_device_properties(device)
     gcn_arch = getattr(props, "gcnArchName", "")
     if gcn_arch.startswith("gfx950"):
```
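A minimal sketch (not from the PR) of why this matters on multi-GPU nodes; the second GPU and index `1` below are assumptions for illustration. `get_device_properties(0)` always describes the first GPU, while `current_device()` follows `torch.cuda.set_device` / `with torch.cuda.device(...)`:

```python
import torch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    with torch.cuda.device(1):  # make GPU 1 the active device
        fixed0 = torch.cuda.get_device_properties(0)  # always GPU 0
        active = torch.cuda.get_device_properties(torch.cuda.current_device())  # GPU 1 here
        # On a mixed ROCm node these can differ, e.g. gfx942 vs gfx950,
        # which would flip the FN/FNUZ decision above.
        print(getattr(fixed0, "gcnArchName", fixed0.name))
        print(getattr(active, "gcnArchName", active.name))
```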
🤖 Prompt for AI Agents

> In `tilelang/utils/target.py` around lines 73-79, the dtype-selection logic queries device 0 via `torch.cuda.get_device_properties(0)`. Change it to use the active CUDA/HIP device: call `torch.cuda.current_device()` and pass that index into `torch.cuda.get_device_properties` before inspecting `gcnArchName`, so the `gcn_arch.startswith("gfx950")` check reflects the currently selected GPU.
**Remove the unused `# noqa: F401` directive to satisfy Ruff.**

Ruff flags the directive as unused on this line, which can fail linting. Either drop the directive or enable `F401` in the Ruff config.
🧰 Tools

🪛 Ruff (0.14.14)

```
3-3: Unused `noqa` directive (non-enabled: F401)
     Remove unused `noqa` directive (RUF100)
```
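For illustration, a minimal sketch of the RUF100 fix; `from typing import Optional` is a hypothetical stand-in for the flagged line 3, which the capture does not show:

```diff
-from typing import Optional  # noqa: F401
+from typing import Optional
```

Since `F401` is not in the enabled rule set, the directive suppresses nothing, so removing it changes no lint behavior.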