Add CPU Cast op support for Float8E8M0#28435
Conversation
…er.py) Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/01411378-b38b-4058-a023-e838a46a2252 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Cross-reviewed against both the ONNX numpy_helper.py Before → After (for normal floats):
Also fixed subnormal rounding thresholds to match the ONNX G/R/S scheme, and replaced self-referential expected values in the three round-mode tests with hardcoded Warning Firewall rules blocked me from connecting to one or more addresses (expand for details)I tried to connect to the following addresses, but was blocked by firewall rules:
If you need me to access, download, or install something from one of these locations, you can either:
|
…tations - Fix denorm Up (ceiling) path: mantissa == 0x400000 is exactly 2^(-127), which is representable as val=0. Changed threshold from (mantissa > 0) to (mantissa > 0x400000) so exact 2^(-127) values don't incorrectly round up to val=1. - Fix Rounding test: 1.25 with Up (ceiling) mode rounds up to 2.0 (val=128), not down to 1.0 (val=127). - Fix SubnormalRounding test: subnormals are valid positive values within E8M0 range, so saturate=false does not produce NaN for them.
…with DISABLE_FLOAT8_TYPES When DISABLE_FLOAT8_TYPES is defined (minimal/wasm/Android builds), Float8E8M0 is not declared, so boost::mp11::mp_remove<..., Float8E8M0> fails to compile. Since AllIRv10WithInt2 already excludes Float8E8M0 in that case, the PreOpset24 type lists are equivalent to the full ones. Fixes CI failures in: - Linux CPU Minimal Build E2E (6a/6b/6c/7) - Android CI (AndroidBinarySizeCheckJob_MinimalBaseline, android_nnapi_ep) - Wasm builds (wasm_Debug, wasm_Release) - WebGPU builds
ReviewOverall a well-scoped, well-tested change. The new dispatcher cleanly threads Major1. The "backward-compatible" claim is misleading. The ctor signature is compatible (default arg), but
Both are in a public header. Suggest one of:
2. Subnormal Minor3. 4. CUDA Cast registration not updated. Opset 24+ Cast on CUDA still omits 5. Duplicate 6. 7. Test oracles use the function under test. 8. Missing negative tests:
Nits
Praise
Suggested action: Major #1 is the only item I'd consider a blocker — at minimum a documentation/changelog issue, at most warrants reverting the ctor default. Major #2 deserves a quick confirmation that matching the ONNX reference over the prose is intentional. The rest is polish. |
There was a problem hiding this comment.
Pull request overview
This PR extends ONNX Runtime’s CPU Cast operator to support the FLOAT8E8M0 destination type, including parsing/handling of the round_mode attribute and updating Float8E8M0 conversion semantics and tests accordingly.
Changes:
- Added
Float8E8M0::RoundModeand threadedround_modethrough CPU Cast when casting toFLOAT8E8M0. - Extended CPU Cast kernel type constraints/dispatch to include
FLOAT8E8M0for opset 24+ while keeping earlier opsets’ constraints unchanged. - Added/updated unit tests and docs to reflect
FLOAT8E8M0Cast support and revised conversion behavior.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| include/onnxruntime/core/common/float8.h | Adds RoundMode and updates float→E8M0 conversion logic to support multiple rounding modes. |
| onnxruntime/core/providers/cpu/tensor/cast_op.cc | Enables CPU Cast to/from FLOAT8E8M0, validates attributes, and adds a dedicated dispatcher to pass round_mode. |
| onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | Adds CPU Cast tests covering float8e8m0 conversions, saturate behavior, and round_mode variants. |
| onnxruntime/test/framework/float8e8m0_test.cc | Updates existing Float8E8M0 unit test expectations/comments for the revised default rounding behavior. |
| docs/OperatorKernels.md | Updates the kernel support table to include tensor(float8e8m0) for Cast in opset 24 and 25+. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Description
This PR adds Float8E8M0 support to the CPU Cast operator, implementing follow-up items 3 (Cast op support) and 4 (rounding mode) from PR #28381.
Changes
Float8E8M0 RoundMode (
include/onnxruntime/core/common/float8.h)RoundModeenum (Up,Down,Nearest) toFloat8E8M0structround_modeparameter (default:Up)Up/Nearest: ties round away from zero (higher power of 2)Down: ties round towards zero (lower power of 2)UpmodeCPU Cast op (
onnxruntime/core/providers/cpu/tensor/cast_op.cc)Float8E8M0to the Cast kernel's enabled type list (AllIRv10WithInt2)FLOAT8E8M0to thesaturateattribute validationround_modeattribute parsing ("up","down","nearest") per the ONNX opset 25 Cast schemaCastToE8M0Dispatchertemplate that handles casting from any source type (float, double, int, MLFloat16, BFloat16, Int4, Int2, string, other Float8 types) to Float8E8M0 with propersaturateandround_modesupportround_modethroughTests (
onnxruntime/test/providers/cpu/tensor/cast_op_test.cc)FloatToFloat8E8M0_Saturate/_NoSaturate— basic float→E8M0 with saturate on/offFloatToFloat8E8M0_RoundModeUp/_RoundModeDown/_RoundModeNearest— all three rounding modesFloat8E8M0ToFloat/Float8E8M0ToDouble— E8M0→float/double conversionMLFloat16ToFloat8E8M0/DoubleToFloat8E8M0/Int32ToFloat8E8M0— various source typesMotivation and Context
PR #28381 added the Float8E8M0 data type to ORT but deferred Cast op support. This PR completes the Cast op integration so that models can use
Cast(to=FLOAT8E8M0)with propersaturateandround_modesemantics as defined in the ONNX opset 25 spec.Testing
All 10 new Cast op tests pass. All 19 existing Float8E8M0 unit tests and all 4 existing Float8 Cast tests continue to pass.