Add new data type Float8_e8m0fnu #4665

Merged: 19 commits into main, Jun 26, 2025

Conversation

zasdfgbnm (Collaborator) commented Jun 24, 2025

Yet another variant of fp8, commonly used as scaling factors for mxfp4
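
Float8_e8m0fnu stores only a biased 8-bit exponent: no sign bit, no mantissa, no infinities, and 0xFF reserved for NaN. Every representable value is a power of two (from 2^-127 up to 2^127), which is exactly what a per-block scaling factor for mxfp4 needs. Below is a minimal host-side sketch of the decoding rule, assuming the standard bias of 127; the helper name is illustrative and not part of this PR.

#include <cmath>
#include <cstdint>

// Hypothetical helper: decode a raw e8m0 byte into a float.
// The byte is a biased exponent; the value is 2^(byte - 127).
// 0xFF encodes NaN; the format has no zero, no sign, and no infinity.
inline float e8m0_to_float(uint8_t raw) {
  if (raw == 0xFF) {
    return NAN;
  }
  return std::ldexp(1.0f, static_cast<int>(raw) - 127);
}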


github-actions bot commented Jun 24, 2025

Review updated until commit badeeb2

Description

  • Added support for new data type Float8_e8m0fnu

  • Updated CUDA version checks for Float8_e8m0fnu

  • Included Float8_e8m0fnu in various type handling functions

  • Added conversion functions for Float8_e8m0fnu in CUDA


Changes walkthrough 📝

Relevant files

Enhancement (16 files)

  • codegen.cpp: Added `Float8_e8m0fnu` to CUDA kernel generator (+1/-0)
  • device_version.cpp: Added CUDA version check for `Float8_e8m0fnu` (+12/-0)
  • index.cpp: Included `Float8_e8m0fnu` in bitwise-or check (+2/-1)
  • utils.cpp: Added min/max value handling for `Float8_e8m0fnu` (+10/-0)
  • allocations.cpp: Included `Float8_e8m0fnu` in NaN filling (+1/-0)
  • executor_kernel_arg.cpp: Added conversion for `Float8_e8m0fnu` in polymorphic value conversion (+6/-1)
  • type.cpp: Added `Float8_e8m0fnu` to various type handling functions (+38/-3)
  • type_promotion.cpp: Included `Float8_e8m0fnu` in type promotion logic (+5/-3)
  • validator_utils.cpp: Included `Float8_e8m0fnu` in tolerance calculation (+1/-0)
  • python_utils.cpp: Added `Float8_e8m0fnu` to Python string conversion (+2/-0)
  • enum.cpp: Added `Float8_e8m0fnu` to Python enum bindings (+1/-0)
  • python_bindings.cpp: Added `Float8_e8m0fnu` to Python bindings (+1/-0)
  • pytorch_utils.py: Added `Float8_e8m0fnu` to dtype mapping (+1/-0)
  • utils.py: Added `Float8_e8m0fnu` to argument type enum (+1/-0)
  • fp8_support.cu: Added conversion functions for `Float8_e8m0fnu` (+144/-0)
  • type.h: Added `Float8_e8m0fnu` to type definitions (+10/-1)

Cleanup (1 file)

  • arith.cpp: Removed unnecessary includes for FP8 types (+0/-5)

Tests (2 files)

  • test_gpu1.cpp: Added tests for `Float8_e8m0fnu` casting (+35/-8)
  • test_python_frontend.py: Added tests for `Float8_e8m0fnu` in Python frontend (+5/-1)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

CUDA Version Check

Ensure that the CUDA version checks are correctly implemented and that the error messages are clear and consistent.

        "(9.0)");
#else
    NVF_ERROR(
        "Fusion contains Float8_xxx values which was not supported in given "
        "CUDA version");
#endif // (CUDA_VERSION >= 12010)
  }
  if (val->dtype() == DataType::Float8_e8m0fnu) {
#if (CUDA_VERSION >= 12070)
    ensureVersion(
        {10, 0},
        "Fusion contains Float8_e8m0fnu values which was introduced in "
        "Blackwell (10.0)");
#else
    NVF_ERROR(
        "Fusion contains Float8_e8m0fnu values which was not supported in "
        "given CUDA version");
#endif // (CUDA_VERSION >= 12070)
Tolerance Handling

Verify that the tolerance levels for Float8_e8m0fnu are appropriate and that the TODO comment is addressed.

    return {abs_tol, abs_tol * 0.01};
  }
}
// TODO: fp8 likely will need higher tolerance.
case DataType::Float8_e4m3fn:
case DataType::Float8_e5m2:
case DataType::Float8_e8m0fnu:
case DataType::BFloat16: {
  // Copied from float case
  const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
  const auto& base_abs = tolerances.base_half_abs_tol;
Device Code Optimization

Review the device code for Float8_e8m0fnu to ensure that it is optimized and that all necessary conversions are correctly implemented.

struct __e8m0;
__device__ __inline__ __e8m0 __float2e8m0(const float);
__device__ __inline__ __e8m0 __double2e8m0(const double);

struct __align__(1) __e8m0 {
  __e8m0() = default;

  __e8m0(const __e8m0& other) {
    __x = other.__x;
  }

  __e8m0(const __e8m0&& other) {
    __x = other.__x;
  }

  __e8m0(const volatile __e8m0& other) {
    __x = other.__x;
  }

  __e8m0(const volatile __e8m0&& other) {
    __x = other.__x;
  }

  // Note: not returning reference for `__e8m0::operator=`
  // Doing so would require us to return `volatile __e8m0&` for the volatile
  // variants, which would trigger a gcc warning `implicit dereference will not
  // access object of type ‘volatile S’ in statement`
  __device__ void operator=(const __e8m0& other) {
    __x = other.__x;
  }

  __device__ void operator=(const __e8m0&& other) {
    __x = other.__x;
  }

  __device__ void operator=(const volatile __e8m0& other) {
    __x = other.__x;
  }

  __device__ void operator=(const volatile __e8m0&& other) {
    __x = other.__x;
  }

  __device__ void operator=(const __e8m0& other) volatile {
    __x = other.__x;
  }

  __device__ void operator=(const __e8m0&& other) volatile {
    __x = other.__x;
  }

  __device__ void operator=(const volatile __e8m0& other) volatile {
    __x = other.__x;
  }

  __device__ void operator=(const volatile __e8m0&& other) volatile {
    __x = other.__x;
  }

  __device__ __e8m0(const float f) {
    __x = __float2e8m0(f).__x;
  }

  __device__ __e8m0(const double f) {
    __x = __double2e8m0(f).__x;
  }

  __device__ __e8m0(const int x) : __x(x) {}

  __device__ __e8m0(const long long x) : __x(x) {}

  __device__ __e8m0(const uint8_t x) : __x(x) {}

  __device__ __e8m0(const uint16_t x) : __x(x) {}

  __device__ uint8_t raw() const {
    return __x;
  }

 protected:
  uint8_t __x;
};

// see NOTE [ fp8 cast optimization ]
__device__ __inline__ __e8m0 __float2e8m0(const float f) {
  constexpr float f_const_zero = 0.f;
  unsigned short _tmp_buffer;
  __e8m0 val;
  asm("{cvt.rz.satfinite.ue8m0x2.f32 %0, %1, %2;}"
      : "=h"(_tmp_buffer)
      : "f"(f_const_zero), "f"(f));
  memcpy(&val, &_tmp_buffer, sizeof(uint8_t));

  return val;
}

__device__ __inline__ float __e8m02float(const __e8m0 b) {
  unsigned short _tmp_buffer;
  memcpy(&_tmp_buffer, &b, sizeof(uint8_t));
  float val;
  asm("{\n\t"
      ".reg .b32 buf0;\n\t"
      "cvt.rn.bf16x2.ue8m0x2 buf0, %1;\n\t"
      "cvt.u16.u32 %1, buf0;\n\t"
      "cvt.f32.bf16 %0, %1;\n\t"
      "}"
      : "=f"(val)
      : "h"(_tmp_buffer));

  return val;
}

__device__ __inline__ __e8m0 __double2e8m0(const double f) {
  return __float2e8m0(f);
}

__device__ __inline__ double __e8m02double(const __e8m0 b) {
  return __e8m02float(b);
}

__device__ __inline__ __e8m0 __half2e8m0(const __half h) {
  return __float2e8m0(__half2float(h));
}

__device__ __inline__ __half __e8m02half(const __e8m0 b) {
  return __float2half(__e8m02float(b));
}

__device__ __inline__ __e8m0 __bfloat2e8m0(const __bfloat h) {
  return __float2e8m0(__bfloat2float(h));
}

__device__ __inline__ __bfloat __e8m02bfloat(const __e8m0 b) {
  return __float2bfloat(__e8m02float(b));
}

__device__ __inline__ __e8m0 operator|(const __e8m0 x, const __e8m0 y) {
  unsigned short val;
  unsigned short x_val = x.raw();
  unsigned short y_val = y.raw();
  asm("{  or.b16 %0, %1, %2;}\n" : "=h"(val) : "h"(x_val), "h"(y_val));
  return __e8m0(val);
}
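
For context, a minimal sketch of how the helpers above could be exercised in a hand-written kernel; this kernel is illustrative only and not part of the PR (in nvFuser the casts are emitted by the CUDA kernel generator).

// Hypothetical test kernel: cast float -> e8m0 -> float element-wise.
__global__ void e8m0_roundtrip(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __e8m0 scaled = __float2e8m0(in[i]);  // cvt.rz.satfinite: round toward zero, saturate
    out[i] = __e8m02float(scaled);        // converts back through bf16, as in the asm above
  }
}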

Comment on lines +2736 to +2742
? (outputs[0]
       .as<at::Tensor>()
       .ge(ref_output / 2)
       .logical_and(
           outputs[0].as<at::Tensor>().le(ref_output * 2))
       .all()
       .item<bool>())
zasdfgbnm (Collaborator, Author):
Looks like PyTorch's implementation is different from NVIDIA's implementation...

Collaborator:
Could you explain what you're doing here? Looks like it accounts for a 2x error range?

zasdfgbnm (Collaborator, Author):
Yes, exactly. e8m0 has no mantissa, so it can only represent values of the form 1*2^x, and I am asserting that, if the results differ, x differs by at most 1.
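
Since every e8m0 value is a power of two, the check above (ref_output / 2 <= output <= ref_output * 2) is equivalent to allowing the stored exponent to be off by at most one step. A small standalone sketch of that equivalence, assuming exact powers of two as inputs (not part of the PR):

#include <cassert>
#include <cmath>
#include <cstdlib>

// For powers of two, ref/2 <= out <= ref*2 is the same as saying that the
// binary exponents of out and ref differ by at most 1.
bool within_one_exponent_step(float out, float ref) {
  int out_exp = 0;
  int ref_exp = 0;
  std::frexp(out, &out_exp);
  std::frexp(ref, &ref_exp);
  return std::abs(out_exp - ref_exp) <= 1;
}

int main() {
  assert(within_one_exponent_step(std::ldexp(1.0f, 10), std::ldexp(1.0f, 11)));   // 2^10 vs 2^11: accepted
  assert(!within_one_exponent_step(std::ldexp(1.0f, 10), std::ldexp(1.0f, 12)));  // 2^10 vs 2^12: rejected
  return 0;
}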

zasdfgbnm marked this pull request as ready for review June 25, 2025 01:03
zasdfgbnm (Collaborator, Author) commented:

!test

zasdfgbnm requested a review from naoyam June 25, 2025 01:03
zasdfgbnm (Collaborator, Author) commented:

!test

naoyam (Collaborator) left a comment:

LGTM

zasdfgbnm merged commit 5008027 into main Jun 26, 2025
30 of 39 checks passed
zasdfgbnm deleted the Float8_e5m2 branch June 26, 2025 15:28