Add new data type Float8_e8m0fnu #4665

Merged on Jun 26, 2025 (19 commits). Showing changes from all commits.
csrc/codegen.cpp (+1, -0)
@@ -217,6 +217,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
case DataType::BFloat16:
case DataType::Float8_e4m3fn:
case DataType::Float8_e5m2:
case DataType::Float8_e8m0fnu:
return "f";
case DataType::Int:
// We use the LL suffix for int64_t literals
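
The suffix switch above prints scalar literals: every narrow floating-point type, e8m0 included, shares the float suffix "f", because CUDA has no literal syntax for fp8, so constants are emitted as float and converted where consumed. A hedged sketch of the kind of kernel text this yields (the __e8m0 type and __float2e8m0 helper names come from the type.cpp strings later in this PR; the exact emitted code is illustrative):

  // Illustrative only: an e8m0 constant is printed as a float literal
  // carrying the "f" suffix, then routed through the float->e8m0 cast.
  __e8m0 scale = __float2e8m0(2.0f);
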
csrc/device_lower/analysis/device_version.cpp (+12, -0)
@@ -39,6 +39,18 @@ void MinimumDeviceVersion::dispatch(Val* val) {
"Fusion contains Float8_xxx values which was not supported in given "
"CUDA version");
#endif // (CUDA_VERSION >= 12010)
}
if (val->dtype() == DataType::Float8_e8m0fnu) {
#if (CUDA_VERSION >= 12070)
ensureVersion(
{10, 0},
"Fusion contains Float8_e8m0fnu values which was introduced in "
"Blackwell (10.0)");
#else
NVF_ERROR(
"Fusion contains Float8_e8m0fnu values which was not supported in "
"given CUDA version");
#endif // (CUDA_VERSION >= 12070)
}
IterVisitor::dispatch(val);
}
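
e8m0 support is gated twice: the toolkit must know the type at compile time (CUDA_VERSION >= 12070) and the device must be Blackwell or newer (compute capability 10.0) at run time. A minimal standalone sketch of the same double gate, assuming the usual CUDA_VERSION encoding of major * 1000 + minor * 10:

  #include <cuda.h>  // defines CUDA_VERSION, e.g. 12080 for CUDA 12.8

  // Can e8m0 kernels be compiled by this toolkit and run on this device?
  bool e8m0Supported(int major_compute_capability) {
  #if (CUDA_VERSION >= 12070)
    return major_compute_capability >= 10;  // Blackwell (10.0) and newer
  #else
    (void)major_compute_capability;
    return false;  // toolkit predates the e8m0 type
  #endif
  }
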
csrc/device_lower/pass/index.cpp (+2, -1)
@@ -2783,7 +2783,8 @@ void IndexLowering::handle(const CatOp* cat) {

DataType dt = out->dtype();
bool use_bitwise_or = dt == DataType::Half || dt == DataType::BFloat16 ||
dt == DataType::Float8_e4m3fn || dt == DataType::Float8_e5m2;
dt == DataType::Float8_e4m3fn || dt == DataType::Float8_e5m2 ||
dt == DataType::Float8_e8m0fnu;
BinaryOpType op_type =
use_bitwise_or ? BinaryOpType::BitwiseOr : BinaryOpType::Add;

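
CatOp lowers to a pointwise combine of inputs padded with zero bit patterns, and for the narrow float types the combine is BitwiseOr rather than Add: OR against an all-zero pad reproduces the live lane bit for bit, with no rounding and no device arithmetic. For e8m0 this is not just cheaper but necessary, since the format is exponent-only and has no encoding for numeric zero, so an additive identity does not exist. A small host-side model of the idea:

  #include <cstdint>

  // Combine two zero-padded one-byte fp8 lanes; exactly one operand is a
  // live value, the other is the 0x00 pad bit pattern. For e8m0, 0x00
  // decodes to 2^-127 rather than zero, so Add would be wrong while OR
  // is always exact.
  inline uint8_t catCombine(uint8_t lane_a, uint8_t lane_b) {
    return lane_a | lane_b;
  }
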
csrc/ops/arith.cpp (+0, -5)
@@ -17,11 +17,6 @@
#include <type.h>
#include <type_promotion.h>

#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>

#include <cfloat>

namespace nvfuser {
csrc/ops/utils.cpp (+10, -0)
@@ -543,6 +543,11 @@ Val* getMinimumValue(DataType v) {
return IrBuilder::create<Val>(static_cast<double>(
-std::numeric_limits<c10::Float8_e5m2>::infinity()));
break;
case DataType::Float8_e8m0fnu:
// e8m0 is finite.
return IrBuilder::create<Val>(static_cast<double>(
-std::numeric_limits<c10::Float8_e8m0fnu>::max()));
break;
case (DataType::Int):
return IrBuilder::create<Val>(std::numeric_limits<int64_t>::lowest());
break;
@@ -588,6 +593,11 @@ Val* getMaximumValue(DataType v) {
return IrBuilder::create<Val>(static_cast<double>(
std::numeric_limits<c10::Float8_e5m2>::infinity()));
break;
case DataType::Float8_e8m0fnu:
// e8m0 is finite.
return IrBuilder::create<Val>(
static_cast<double>(std::numeric_limits<c10::Float8_e8m0fnu>::max()));
break;
case (DataType::Int):
return IrBuilder::create<Val>(std::numeric_limits<int64_t>::max());
break;
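
Unlike e5m2, which has infinities to hand back, e8m0 is finite, so its extremes come from numeric_limits<c10::Float8_e8m0fnu>::max(), negated for the lower bound. The format itself is tiny: 8 exponent bits, no sign, no mantissa, bias 127, one NaN encoding, and no representation for zero (assumed OCP MX-scale semantics). A hedged decoder makes the bounds concrete:

  #include <cmath>
  #include <cstdint>
  #include <limits>

  // Decode one e8m0 byte (assumed semantics: 8 exponent bits, bias 127,
  // no sign, no mantissa, 0xFF = NaN, no infinities, no zero).
  double decodeE8m0(uint8_t bits) {
    if (bits == 0xFF) return std::numeric_limits<double>::quiet_NaN();
    return std::ldexp(1.0, static_cast<int>(bits) - 127);
  }
  // decodeE8m0(0x00) == 2^-127, the smallest value (zero is unrepresentable)
  // decodeE8m0(0xFE) == 2^127, i.e. the max() used above
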
csrc/runtime/allocations.cpp (+1, -0)
@@ -240,6 +240,7 @@ void fillTensorWithNan(at::Tensor& t) {
case at::ScalarType::BFloat16:
case at::ScalarType::Float8_e4m3fn:
case at::ScalarType::Float8_e5m2:
case at::ScalarType::Float8_e8m0fnu:
t.fill_(std::nan(""));
break;
case at::ScalarType::ComplexHalf:
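
fillTensorWithNan poisons freshly allocated outputs so that reads of uninitialized memory show up as NaNs in tests. e8m0 can join the other fp8 cases because, despite lacking infinities, a sign, and zero, it does reserve a NaN encoding (0xFF under the semantics assumed above). A byte-level sketch of the same poisoning:

  #include <cstddef>
  #include <cstdint>
  #include <cstring>

  // Poison an e8m0 buffer with its single NaN encoding (assumed 0xFF).
  void fillE8m0WithNan(uint8_t* data, size_t num_bytes) {
    std::memset(data, 0xFF, num_bytes);
  }
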
csrc/runtime/executor_kernel_arg.cpp (+6, -1)
@@ -260,11 +260,16 @@ std::vector<std::byte> polymorphicValueToBytes(
at::Float8_e5m2 v8 = (at::Float8_e5m2)(float)v;
return std::vector<std::byte>(
(std::byte*)&v8, (std::byte*)&v8 + sizeof(at::Float8_e5m2));
} else if (dtype == DataType::Float8_e8m0fnu) {
at::Float8_e8m0fnu v8 = (at::Float8_e8m0fnu)(float)v;
return std::vector<std::byte>(
(std::byte*)&v8, (std::byte*)&v8 + sizeof(at::Float8_e8m0fnu));
} else {
NVF_THROW(
"Cannot convert double to ",
dtype,
" type: only half, bfloat16, float and double are supported.");
" type: only half, bfloat16, float, double, fp8_e4m3fn, fp8_e5m2, "
"fp8_e8m0fnu are supported.");
}
} else if (argument.is<std::complex<double>>()) {
// FUSER_PERF_SCOPE("polymorphicValueToBytes(std::complex<double>)");
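
Scalar fusion inputs are serialized into raw kernel-argument bytes, and a double headed for an e8m0 parameter is narrowed through float and re-encoded as the one-byte type, mirroring the e4m3/e5m2 branches above. A self-contained sketch of that path, assuming PyTorch's c10/util/Float8_e8m0fnu.h header provides at::Float8_e8m0fnu:

  #include <cstddef>
  #include <cstring>
  #include <vector>
  #include <c10/util/Float8_e8m0fnu.h>  // assumed location of the type

  // Narrow double -> float -> e8m0, then copy the single byte into the
  // kernel-argument buffer (memcpy in place of the pointer casts above).
  std::vector<std::byte> e8m0ArgBytes(double v) {
    auto v8 = static_cast<at::Float8_e8m0fnu>(static_cast<float>(v));
    std::vector<std::byte> out(sizeof(v8));
    std::memcpy(out.data(), &v8, sizeof(v8));
    return out;
  }
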
csrc/type.cpp (+38, -3)
@@ -112,21 +112,27 @@ bool isInclusiveType(const DataType& base_type, const DataType& wider_type) {
(base_type == DataType::Double || base_type == DataType::Float ||
base_type == DataType::Half || base_type == DataType::BFloat16 ||
base_type == DataType::Float8_e4m3fn ||
base_type == DataType::Float8_e5m2)) {
base_type == DataType::Float8_e5m2 ||
base_type == DataType::Float8_e8m0fnu)) {
return true;
}
if ((wider_type == DataType::Float || wider_type == DataType::ComplexFloat) &&
(base_type == DataType::Float || base_type == DataType::Half ||
base_type == DataType::BFloat16 ||
base_type == DataType::Float8_e4m3fn ||
base_type == DataType::Float8_e5m2)) {
base_type == DataType::Float8_e5m2 ||
base_type == DataType::Float8_e8m0fnu)) {
return true;
}
if ((wider_type == DataType::Half || wider_type == DataType::BFloat16) &&
(base_type == DataType::Float8_e4m3fn ||
base_type == DataType::Float8_e5m2)) {
return true;
}
if (wider_type == DataType::BFloat16 &&
base_type == DataType::Float8_e8m0fnu) {
return true;
}
if ((wider_type == DataType::Int || wider_type == DataType::Double ||
wider_type == DataType::ComplexDouble) &&
base_type == DataType::Int32) {
@@ -173,6 +179,9 @@ bool isSupportedTypeByDevice(DataType dtype) {
if (dtype == DataType::Float8_e4m3fn || dtype == DataType::Float8_e5m2) {
return major_ver >= 9;
}
if (dtype == DataType::Float8_e8m0fnu) {
return major_ver >= 10;
}
return true;
}
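
The isInclusiveType changes above encode a subtlety: e4m3fn and e5m2 widen losslessly into both Half and BFloat16, but e8m0fnu widens only into BFloat16. The reason is exponent range: e8m0 spans 2^-127 through 2^127, which bfloat16 covers exactly (it shares float32's 8-bit exponent, and 2^-127 is reachable as a bfloat16 subnormal), while fp16's largest finite value is 65504. A quick hedged check of the extremes:

  #include <cmath>
  #include <cstdio>

  int main() {
    // e8m0 extremes, decoded as powers of two (assumed OCP semantics).
    double lo = std::ldexp(1.0, -127);  // ~5.88e-39
    double hi = std::ldexp(1.0, 127);   // ~1.70e38
    // fp16 max finite is 65504; bfloat16 max finite is ~3.39e38.
    std::printf("e8m0 range [%g, %g]: fits bf16, overflows fp16\n", lo, hi);
    return 0;
  }
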

@@ -227,6 +236,8 @@ static std::string data_type2string(DataType t) {
return "__e4m3";
case DataType::Float8_e5m2:
return "__e5m2";
case DataType::Float8_e8m0fnu:
return "__e8m0";
case DataType::Float4_e2m1:
return "e2m1";
case DataType::Index:
@@ -1227,6 +1238,23 @@ static const char* supported_casts2string(std::pair<DataType, DataType> t) {
case supported_switch_pair(DataType::BFloat16, DataType::Float8_e4m3fn):
return "__bfloat2e4m3";

case supported_switch_pair(DataType::Float8_e8m0fnu, DataType::Float):
return "__e8m02float";
case supported_switch_pair(DataType::Float8_e8m0fnu, DataType::Double):
return "__e8m02double";
case supported_switch_pair(DataType::Float8_e8m0fnu, DataType::Half):
return "__e8m02half";
case supported_switch_pair(DataType::Float8_e8m0fnu, DataType::BFloat16):
return "__e8m02bfloat";
case supported_switch_pair(DataType::Float, DataType::Float8_e8m0fnu):
return "__float2e8m0";
case supported_switch_pair(DataType::Double, DataType::Float8_e8m0fnu):
return "__double2e8m0";
case supported_switch_pair(DataType::Half, DataType::Float8_e8m0fnu):
return "__half2e8m0";
case supported_switch_pair(DataType::BFloat16, DataType::Float8_e8m0fnu):
return "__bfloat2e8m0";

default:
return nullptr;
}
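
Each new cast direction maps to a named runtime helper; the device implementations live in nvFuser's runtime strings and are not part of this diff. As a rough host-side model only (not the real __float2e8m0), a float-to-e8m0 conversion keeps just the binary exponent; the actual runtime may round differently (some MX recipes round the exponent up rather than truncating):

  #include <cmath>
  #include <cstdint>

  // Rough model of float -> e8m0: encode floor(log2(x)) on the e8m0 grid.
  // Assumptions: NaN and negatives map to the NaN byte 0xFF; zero and
  // underflow clamp to 0x00 (2^-127); infinity and overflow clamp to 2^127.
  uint8_t float2e8m0Model(float x) {
    if (std::isnan(x) || x < 0.0f) return 0xFF;
    if (x == 0.0f) return 0x00;      // e8m0 cannot encode zero
    if (std::isinf(x)) return 0xFE;  // clamp to 2^127
    int exp = 0;
    std::frexp(x, &exp);             // x = m * 2^exp with m in [0.5, 1)
    int e = (exp - 1) + 127;         // floor(log2(x)) + bias
    if (e < 0) e = 0;
    if (e > 254) e = 254;
    return static_cast<uint8_t>(e);
  }
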
@@ -1248,6 +1276,8 @@ DataType aten_to_data_type(const at::ScalarType& scalar_type) {
return DataType::Float8_e4m3fn;
case at::ScalarType::Float8_e5m2:
return DataType::Float8_e5m2;
case at::ScalarType::Float8_e8m0fnu:
return DataType::Float8_e8m0fnu;
case at::ScalarType::Char:
return DataType::Char;
case at::ScalarType::Short:
@@ -1290,6 +1320,8 @@ at::ScalarType data_type_to_aten(const DataType& data_type) {
return at::ScalarType::Float8_e4m3fn;
case DataType::Float8_e5m2:
return at::ScalarType::Float8_e5m2;
case DataType::Float8_e8m0fnu:
return at::ScalarType::Float8_e8m0fnu;
case DataType::Index:
NVF_THROW(
"Index is determined at compile time,",
@@ -1574,6 +1606,7 @@ std::string typePrefix(const DataType data_type) {
case DataType::BFloat16:
case DataType::Float8_e4m3fn:
case DataType::Float8_e5m2:
case DataType::Float8_e8m0fnu:
return "f";
case DataType::Index:
case DataType::Int:
@@ -1708,6 +1741,7 @@ int max_digits10(DataType dtype) {
// Type Precision max_digits10
// fp8_e5m2 3 2
// fp8_e4m3 4 3
// fp8_e8m0 1 2
// bfloat16 8 4
// float16 11 5
// float32 24 9
@@ -1723,7 +1757,8 @@
return 4;
} else if (dtype == DataType::Float8_e4m3fn) {
return 3;
} else if (dtype == DataType::Float8_e5m2) {
} else if (
dtype == DataType::Float8_e5m2 || dtype == DataType::Float8_e8m0fnu) {
return 2;
} else {
NVF_CHECK(
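
The max_digits10 table extends cleanly to e8m0 because its precision is a single bit: the implicit leading 1, with no mantissa bits behind it. Using the standard formula max_digits10 = ceil(precision * log10(2)) + 1 gives ceil(0.302) + 1 = 2, which is why e8m0 shares the e5m2 branch. A one-liner reproduces the whole table:

  #include <cmath>

  // max_digits10 from precision in bits (implicit leading 1 included):
  // e8m0 (p=1) -> 2, e5m2 (p=3) -> 2, e4m3 (p=4) -> 3, bf16 (p=8) -> 4,
  // fp16 (p=11) -> 5, fp32 (p=24) -> 9.
  int maxDigits10(int precision_bits) {
    return static_cast<int>(std::ceil(precision_bits * std::log10(2.0))) + 1;
  }
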
csrc/type.h (+10, -1)
@@ -76,6 +76,7 @@ enum class PrimDataType {
BFloat16,
Float8_e4m3fn,
Float8_e5m2,
Float8_e8m0fnu,
Float4_e2m1,
// Integral types
Char,
@@ -190,6 +191,7 @@ struct DataType {
static constexpr PrimDataType Float4_e2m1 = PrimDataType::Float4_e2m1;
static constexpr PrimDataType Float8_e4m3fn = PrimDataType::Float8_e4m3fn;
static constexpr PrimDataType Float8_e5m2 = PrimDataType::Float8_e5m2;
static constexpr PrimDataType Float8_e8m0fnu = PrimDataType::Float8_e8m0fnu;
static constexpr PrimDataType Index = PrimDataType::Index;
static constexpr PrimDataType Char = PrimDataType::Char;
static constexpr PrimDataType Short = PrimDataType::Short;
@@ -268,7 +270,8 @@ bool isInclusiveType(const DataType& base_type, const DataType& type);
inline bool isFloatingPointType(DataType dtype) {
return dtype == DataType::Double || dtype == DataType::Float ||
dtype == DataType::Half || dtype == DataType::BFloat16 ||
dtype == DataType::Float8_e4m3fn || dtype == DataType::Float8_e5m2;
dtype == DataType::Float8_e4m3fn || dtype == DataType::Float8_e5m2 ||
dtype == DataType::Float8_e8m0fnu;
}

// Returns if the datatype is an integer type
@@ -409,6 +412,10 @@ DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
DataType::Float8_e5m2,
at::ScalarType::Float8_e5m2,
at::Float8_e5m2);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
DataType::Float8_e8m0fnu,
at::ScalarType::Float8_e8m0fnu,
at::Float8_e8m0fnu);
DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE(
DataType::Char,
at::ScalarType::Char,
@@ -1104,6 +1111,8 @@ constexpr inline size_t primDataTypeSizeBit(PrimDataType type) {
return sizeof(at::Float8_e4m3fn) * 8;
case DataType::Float8_e5m2:
return sizeof(at::Float8_e5m2) * 8;
case DataType::Float8_e8m0fnu:
return sizeof(at::Float8_e8m0fnu) * 8;
case DataType::Float4_e2m1:
return 4;
case DataType::Index:
csrc/type_promotion.cpp (+5, -3)
@@ -60,7 +60,8 @@ ResultTypeState updateResultTypeState(
ResultTypeState new_state = in_state;
DataType current = scalar;
if (scalar == DataType::Half || scalar == DataType::BFloat16 ||
scalar == DataType::Float8_e4m3fn || scalar == DataType::Float8_e5m2) {
scalar == DataType::Float8_e4m3fn || scalar == DataType::Float8_e5m2 ||
scalar == DataType::Float8_e8m0fnu) {
current = DataType::Float;
}
new_state.wrappedResult =
@@ -197,11 +198,12 @@ DataType computeTypes(
}

auto common_type = computeTypes(config, vt_operands);
// Cast FP16 / BFloat16 to Float
// Cast FP16 / BFloat16 / FP8 to Float
if (cast_half_to_float &&
(common_type == DataType::Half || common_type == DataType::BFloat16 ||
common_type == DataType::Float8_e4m3fn ||
common_type == DataType::Float8_e5m2)) {
common_type == DataType::Float8_e5m2 ||
common_type == DataType::Float8_e8m0fnu)) {
common_type = DataType::Float;
}

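
Both promotion hooks treat every sub-32-bit float alike: when an fp8 or 16-bit float participates in arithmetic, the computation type is promoted to Float, so e8m0 stays 8-bit in storage only. A condensed model of the rule (illustrative enum, not nvFuser's DataType):

  enum class Dt { Half, BFloat16, E4M3, E5M2, E8M0, Float, Double };

  // Narrow floats are widened to fp32 before arithmetic.
  Dt promoteForCompute(Dt t) {
    switch (t) {
      case Dt::Half:
      case Dt::BFloat16:
      case Dt::E4M3:
      case Dt::E5M2:
      case Dt::E8M0:
        return Dt::Float;
      default:
        return t;
    }
  }
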
csrc/validator_utils.cpp (+1, -0)
@@ -107,6 +107,7 @@ std::pair<double, double> getTolerance(
// TODO: fp8 likely will need higher tolerance.
case DataType::Float8_e4m3fn:
case DataType::Float8_e5m2:
case DataType::Float8_e8m0fnu:
case DataType::BFloat16: {
// Copied from float case
const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
python/nvfuser/pytorch_utils.py (+1, -0)
@@ -21,6 +21,7 @@
torch.bfloat16: DataType.BFloat16,
torch.float8_e4m3fn: DataType.Float8_e4m3fn,
torch.float8_e5m2: DataType.Float8_e5m2,
torch.float8_e8m0fnu: DataType.Float8_e8m0fnu,
torch.long: DataType.Int,
torch.int: DataType.Int32,
torch.bool: DataType.Bool,
python/nvfuser/testing/utils.py (+1, -0)
@@ -90,6 +90,7 @@ class ArgumentType(Enum):
torch.int64: "int64",
torch.float8_e4m3fn: "float8_e4m3fn",
torch.float8_e5m2: "float8_e5m2",
torch.float8_e8m0fnu: "float8_e8m0fnu",
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float32: "float32",
python/python_common/python_utils.cpp (+2, -0)
@@ -249,6 +249,8 @@ const char* dtypeToPyString(PrimDataType t) {
return "DataType.Float8_e4m3fn";
case DataType::Float8_e5m2:
return "DataType.Float8_e5m2";
case DataType::Float8_e8m0fnu:
return "DataType.Float8_e8m0fnu";
case DataType::Int:
return "DataType.Int";
case DataType::Int32:
python/python_direct/enum.cpp (+1, -0)
@@ -27,6 +27,7 @@ void bindEnums(py::module& nvfuser) {
.value("BFloat16", DataType::BFloat16)
.value("Float8_e4m3fn", DataType::Float8_e4m3fn)
.value("Float8_e5m2", DataType::Float8_e5m2)
.value("Float8_e8m0fnu", DataType::Float8_e8m0fnu)
.value("ComplexFloat", DataType::ComplexFloat)
.value("ComplexDouble", DataType::ComplexDouble)
.value("Null", DataType::Null);
python/python_frontend/python_bindings.cpp (+1, -0)
@@ -674,6 +674,7 @@ void initNvFuserPythonBindings(PyObject* module) {
.value("BFloat16", DataType::BFloat16)
.value("Float8_e4m3fn", DataType::Float8_e4m3fn)
.value("Float8_e5m2", DataType::Float8_e5m2)
.value("Float8_e8m0fnu", DataType::Float8_e8m0fnu)
.value("ComplexFloat", DataType::ComplexFloat)
.value("ComplexDouble", DataType::ComplexDouble)
.value("Null", DataType::Null);