FP8 Support #2054
This is just an umbrella issue to get started. Feel free to modify it, fill in the blanks, and link sub-issues and related discussions.
The gfx940 ISA supports two fp8 formats: fp8 and bf8. Both formats are supported by mfma, including operands of mixed formats: https://llvm.org/docs/AMDGPU/AMDGPUAsmGFX940.html#vop3. FP8 mfma is plumbed through the AMDGPU LLVM backend (https://reviews.llvm.org/D129906), for example:

```c
// CHECK-GFX940-LABEL: @test_mfma_f32_32x32x16_fp8_bf8
// CHECK-GFX940: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %a, i64 %b, <16 x float> %c, i32 0, i32 0, i32 0)
void test_mfma_f32_32x32x16_fp8_bf8(global v16f* out, long a, long b, v16f c)
{
  *out = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(a, b, c, 0, 0, 0);
}
```

The fp8 operands are packed as i64. The only other amdgcn intrinsic for fp8 types is …

FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

Related references:
- OCP 8-bit Floating Point Specification (OFP8)
- Paper with an overview of fp8 types: "FP8 Formats for Deep Learning"
- Blog post with an overview of fp8 support on H100: https://lambdalabs.com/blog/nvidia-hopper-h100-and-fp8-support
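Since each mfma fp8 source operand is an `i64`, it carries eight packed fp8 lanes. As a rough sketch of that packing (the lane ordering within the i64 is an assumption here, not taken from the ISA manual):

```python
def pack_fp8x8(lanes):
    """Pack eight raw fp8 encodings (one byte each) into a single
    64-bit integer, least-significant byte first.

    NOTE: the actual lane order expected by the mfma builtins is an
    assumption; consult the gfx940 ISA manual for the real layout.
    """
    assert len(lanes) == 8
    return int.from_bytes(bytes(lanes), "little")

def unpack_fp8x8(word):
    """Inverse of pack_fp8x8: split an i64 back into eight fp8 bytes."""
    return list(word.to_bytes(8, "little"))

packed = pack_fp8x8([0x3C, 0x00, 0xFF, 0x01, 0x80, 0x7F, 0x40, 0x10])
```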
FP8 support in LLVM/MLIR: RFC from Sep '22 by @stellaraccident: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.
Since then, the other types plumbed all the way through MLIR are:

```c++
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
```

```c++
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
.Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
```

```mlir
func.func @float_attrs_pass() {
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2
    float_attr = 2. : f8E5M2
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
    float_attr = 2. : f8E4M3FN
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
    float_attr = 2. : f8E5M2FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
    float_attr = 2. : f8E4M3FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
    float_attr = 2. : f8E4M3B11FNUZ
  } : () -> ()
}
```

The corresponding APFloat semantics:

```c++
static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
static constexpr fltSemantics semFloat8E5M2FNUZ = {
    15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3FN = {
    8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static constexpr fltSemantics semFloat8E4M3FNUZ = {
    7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
    4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
```
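As a sanity check on those semantics, the largest finite value of each format can be derived from its `maxExponent` and `precision` fields. A small sketch, assuming the usual APFloat interpretation (precision counts the implicit leading bit, and `fltNanEncoding::AllOnes` formats lose their top mantissa encoding to NaN):

```python
def largest_finite(max_exp, precision, nan_all_ones=False):
    """Largest finite value for an fltSemantics-style description.

    precision includes the implicit leading bit, so the mantissa
    stores precision-1 bits. When NaN is encoded as the all-ones
    pattern (as in E4M3FN), the top mantissa value is unavailable,
    so the largest significand drops by one extra ULP.
    """
    ulp = 2.0 ** (1 - precision)  # value of one mantissa ULP
    sig = 2.0 - (2 * ulp if nan_all_ones else ulp)
    return sig * 2.0 ** max_exp

# These match the commonly cited fp8 ranges.
assert largest_finite(15, 3) == 57344.0                  # f8E5M2
assert largest_finite(8, 4, nan_all_ones=True) == 448.0  # f8E4M3FN
assert largest_finite(7, 4) == 240.0                     # f8E4M3FNUZ
assert largest_finite(4, 4) == 30.0                      # f8E4M3B11FNUZ
```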
amdgcn's fp8 presumably maps to f8E4M3FNUZ (and bf8 to f8E5M2FNUZ), judging by the FNUZ semantics above.
If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?
Also, https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html talks a bit about fp8 on NVIDIA GPUs, which is a useful reference. In general, fp8 is used in a fairly ad-hoc way right now: ISAs just provide conversion and tensor/matrix core ops. For training we also have different fp8 scaling factors for different tensors and need model/framework-level handling there, so that is quite ad-hoc too. So, as we discussed in the meeting, getting a minimal matmul to exercise fp8 + tensor/matrix cores in IREE/SHARK would be a good start and a foundation for everything else. We can then build the other parts on top.
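To make the per-tensor scaling point concrete, here is a minimal sketch. The symmetric-scaling scheme and names are illustrative only, not any particular framework's API: each tensor gets its own scale chosen so its largest magnitude maps onto the format's maximum (448 for E4M3), values are divided by the scale before fp8 storage, and multiplied back on dequantization.

```python
# Per-tensor fp8 scaling sketch. Real frameworks additionally track
# amax history across iterations, handle all-zero tensors, etc.
E4M3_MAX = 448.0  # largest finite value of the OCP E4M3 format

def compute_scale(tensor, fmt_max=E4M3_MAX):
    """Scale so the tensor's largest magnitude maps to fmt_max."""
    amax = max(abs(x) for x in tensor)
    return amax / fmt_max if amax > 0 else 1.0

def quantize(tensor, scale, fmt_max=E4M3_MAX):
    """Divide by the scale and clamp to the representable range.
    (Real code would also round to the nearest fp8 value.)"""
    return [min(max(x / scale, -fmt_max), fmt_max) for x in tensor]

def dequantize(tensor, scale):
    return [x * scale for x in tensor]

weights = [0.1, -2.0, 0.5, 1.5]
s = compute_scale(weights)
roundtrip = dequantize(quantize(weights, s), s)
```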
This is explained a bit in the NVIDIA doc linked in my previous comment.
Support in MLIR/LLVM/AMDGPU already seems quite promising, so as discussed this morning the plan is to first show a very simple example using fp8 in IREE, something like:

```mlir
module {
  func.func @matmul_static(%arg0: tensor<32x32xi8>, %arg1: tensor<32x32xi8>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %1 = tensor.bitcast %arg1 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %2 = linalg.matmul ins(%0, %1 : tensor<32x32xf8E4M3FNUZ>, tensor<32x32xf8E4M3FNUZ>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %2 : tensor<32x32xf32>
  }
}
```

or, to avoid the need to also handle mfma at the same time, just something as simple as:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
  func.func @extend_i8(%arg0: tensor<32xi8>) -> tensor<32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32xi8> to tensor<32xf8E4M3FNUZ>
    %1 = tensor.empty() : tensor<32xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<32xf8E4M3FNUZ>) outs(%1 : tensor<32xf32>) {
    ^bb0(%in: f8E4M3FNUZ, %out: f32):
      %3 = arith.extf %in : f8E4M3FNUZ to f32
      linalg.yield %3 : f32
    } -> tensor<32xf32>
    return %2 : tensor<32xf32>
  }
}
```
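For reference, the `arith.extf` in the second example amounts to the following decode. This is a sketch of the f8E4M3FNUZ encoding as I understand it (bias 8, no infinities, the single NaN encoded as the negative-zero bit pattern 0x80):

```python
import math

def f8e4m3fnuz_to_f32(b):
    """Decode one f8E4M3FNUZ byte to a Python float.

    Layout: 1 sign bit, 4 exponent bits (bias 8), 3 mantissa bits.
    FNUZ: no infinities; the sole NaN is the 0x80 (negative-zero)
    pattern, and zero is unsigned.
    """
    if b == 0x80:
        return math.nan
    sign = -1.0 if (b >> 7) else 1.0
    exp = (b >> 3) & 0xF
    mant = b & 0x7
    if exp == 0:  # subnormal: no implicit leading bit
        return sign * (mant / 8.0) * 2.0 ** (1 - 8)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 8)

assert f8e4m3fnuz_to_f32(0x00) == 0.0
assert f8e4m3fnuz_to_f32(0x40) == 1.0    # exponent field 8 -> 2^0
assert f8e4m3fnuz_to_f32(0x7F) == 240.0  # largest finite value
assert math.isnan(f8e4m3fnuz_to_f32(0x80))
```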
Explanation of the LLVM fp semantics naming convention (source: https://github.com/jax-ml/ml_dtypes?tab=readme-ov-file#float8_e5m2fnuz). In short: the `FN` suffix means only finite values and NaN are representable (no infinities), `UZ` means zero is unsigned (no negative zero; that bit pattern is the NaN), and `B11` marks a non-standard exponent bias of 11.
Looking through support in MLIR and the lowering into NVVM/ROCDL, it seems to be already there as well: MFMA to ROCDL intrinsics, and tensor core instruction lowering.
So for the examples in this comment, #2054 (comment), the extension/truncation should just pass through and compile on AMD. As for the mfma support, it would be great if we could just take a single matmul of the exact mfma shape and have it lower to that operation. Like literally all tile sizes would be 1... it should vectorize to
This is an umbrella issue for allowing fp8 type(s) in SHARK, spanning all the required layers of the stack: Turbine, IREE, MLIR, LLVM, including backends of interest like ROCm.
Some initial research is required to scope this properly and divide it into subtasks, but the main work items are roughly:
- llvm::APFloat