Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}

// Handle conversion from float8 (E4M3/E5M2) to float32
if (tl::IsCudaVectorizableFP8(from_ty) && target_ty.is_float()) {
if (tl::IsCudaVectorizableFP8(from_ty) && target_ty.is_float() &&
target_ty.bits() == 32) {
bool from_type_is_e4m3 =
from_ty.is_float8_e4m3() || from_ty.is_float8_e4m3fn();
std::string type_suffix = from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2";
Expand Down Expand Up @@ -1280,7 +1281,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
}

// Handle conversion from float to float8 (E8M0)
// Handle conversion from float32 to float8 (E8M0)
if (from_ty.is_float() && from_ty.bits() == 32 &&
target_ty.is_float8_e8m0fnu()) {
// Use __tl_cvt_float2_to_e8m0x2 for vectorized conversion (float2 ->
Expand Down Expand Up @@ -1315,7 +1316,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}

// Handle conversion from float32 to float4 (E2M1)
if (from_ty.is_float() && target_ty.is_float4_e2m1fn()) {
if (from_ty.is_float() && from_ty.bits() == 32 &&
target_ty.is_float4_e2m1fn()) {
// Use __tl_cvt_float2_to_fp4x2 for vectorized conversion (float2 -> fp4x2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_float2_to_fp4x2", "float2", "uint8_t", "",
Expand All @@ -1335,7 +1337,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}

// Handle conversion from float4 (E2M1) to float32
if (from_ty.is_float4_e2m1fn() && target_ty.is_float()) {
if (from_ty.is_float4_e2m1fn() && target_ty.is_float() &&
target_ty.bits() == 32) {
// Use __tl_cvt_fp4x2_to_float2 for vectorized conversion (fp4x2 -> float2)
if (lanes == 2 || lanes == 4 || lanes == 8) {
PrintVectorizedCast("__tl_cvt_fp4x2_to_float2", "uint8_t", "float2", "",
Expand Down
68 changes: 58 additions & 10 deletions src/target/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,29 +162,43 @@ bool IsCudaVectorizableFP8(DataType dtype) {

bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) {
// float16 -> float32
if (from_ty.is_float16() && target_ty.is_float())
if (from_ty.is_float16() && target_ty.is_float() && target_ty.bits() == 32)
return true;

// float32 -> float16
if (from_ty.is_float() && target_ty.is_float16())
if (from_ty.is_float() && from_ty.bits() == 32 && target_ty.is_float16())
return true;

// bfloat16 -> float32
if (from_ty.is_bfloat16() && target_ty.is_float())
if (from_ty.is_bfloat16() && target_ty.is_float() && target_ty.bits() == 32)
return true;

// float32 -> bfloat16
if (from_ty.is_float() && target_ty.is_bfloat16())
if (from_ty.is_float() && from_ty.bits() == 32 && target_ty.is_bfloat16())
return true;

// float32 -> float8 (E4M3/E5M2)
if (from_ty.is_float() && IsCudaVectorizableFP8(target_ty))
if (from_ty.is_float() && from_ty.bits() == 32 &&
IsCudaVectorizableFP8(target_ty))
return true;

// float8 (E4M3/E5M2) -> float32
if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float())
if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float() &&
target_ty.bits() == 32)
return true;

// Not implemented for now

// float64(double) -> float8 (E4M3/E5M2)
// if (from_ty.is_float() && from_ty.bits() == 64 &&
// IsCudaVectorizableFP8(target_ty))
// return true;

// float8 (E4M3/E5M2) -> float64(double)
// if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float() &&
// target_ty.bits() == 64)
// return true;

// float8 (E8M0) -> bfloat16
if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16())
return true;
Expand All @@ -193,16 +207,50 @@ bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) {
if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu())
return true;

// float32/double -> float8 (E8M0)
if (from_ty.is_float() && target_ty.is_float8_e8m0fnu())
// float32 -> float8 (E8M0)
if (from_ty.is_float() && from_ty.bits() == 32 &&
target_ty.is_float8_e8m0fnu())
return true;

// float64(double) -> float8 (E8M0)
if (from_ty.is_float() && from_ty.bits() == 64 &&
target_ty.is_float8_e8m0fnu())
return true;

// float4_e2m1fn -> float16
if (from_ty.is_float4_e2m1fn() && target_ty.is_float16())
return true;

// float16 -> float4_e2m1fn
if (from_ty.is_float16() && target_ty.is_float4_e2m1fn())
return true;

// float4_e2m1fn -> float32
if (from_ty.is_float4_e2m1fn() && target_ty.is_float())
if (from_ty.is_float4_e2m1fn() && target_ty.is_float() &&
target_ty.bits() == 32)
return true;

// float32 -> float4_e2m1fn
if (from_ty.is_float() && target_ty.is_float4_e2m1fn())
if (from_ty.is_float() && from_ty.bits() == 32 &&
target_ty.is_float4_e2m1fn())
return true;

// float4_e2m1fn -> float64(double)
if (from_ty.is_float4_e2m1fn() && target_ty.is_float() &&
target_ty.bits() == 64)
return true;

// float64(double) -> float4_e2m1fn
if (from_ty.is_float() && from_ty.bits() == 64 &&
target_ty.is_float4_e2m1fn())
return true;

// float4_e2m1fn -> bfloat16
if (from_ty.is_float4_e2m1fn() && target_ty.is_bfloat16())
return true;

// bfloat16 -> float4_e2m1fn
if (from_ty.is_bfloat16() && target_ty.is_float4_e2m1fn())
return true;

return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str,
(T.float32, T.float16, "__float22half2_rn", 4),
(T.float16, T.float32, "__half22float2", 2),
(T.float16, T.float32, "__half22float2", 4),
(T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 2),
(T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 4),
(T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 2),
(T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 4),
(T.float32, T.bfloat16, "__float22bfloat162_rn", 2),
(T.float32, T.bfloat16, "__float22bfloat162_rn", 4),
(T.bfloat16, T.float32, "__bfloat1622float2", 2),
Expand All @@ -105,6 +101,11 @@ def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
@pytest.mark.parametrize(
"src_dtype, dst_dtype, check_str, lanes",
[
# FP8 <-> FP32
(T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 2),
(T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 4),
(T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 2),
(T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 4),
(T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
(T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
(T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
Expand Down
Loading