-
Notifications
You must be signed in to change notification settings - Fork 790
Fix ONNX Resize to support runtime scales input (#4336) #4369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The Resize node codegen only handled runtime sizes input, panicking when scales were provided instead. Per ONNX spec, either sizes or scales can be used (mutually exclusive). This adds support for ResizeScales::Runtime by computing output dimensions from input.dims() * scales at runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds support for runtime scales input to the ONNX Resize operator, fixing a panic that occurred when scales were provided instead of sizes. The implementation follows the existing pattern for runtime sizes handling.
Changes:
- Added runtime scales handling for both Shape and Tensor input types in the Resize operator codegen
- Added 6 snapshot tests covering all interpolation modes (nearest, linear, cubic) for both Shape and Tensor inputs
- Added integration test with runtime scales tensor input and ONNX model generation script
- Added TODO comments referencing issue #4368 for future API refactoring
Reviewed changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| crates/burn-tensor/src/tensor/module.rs | Added TODO comment about refactoring interpolate API to support scale_factor |
| crates/burn-onnx/src/burn/node/resize.rs | Added runtime scales handling logic for Shape and Tensor inputs, plus 6 new snapshot tests |
| crates/burn-onnx/onnx-tests/tests/resize/resize_with_scales_tensor.py | Python script to generate ONNX model for testing runtime scales |
| crates/burn-onnx/onnx-tests/tests/resize/resize_with_scales_tensor.onnx | Generated ONNX model binary for runtime scales test |
| crates/burn-onnx/onnx-tests/tests/resize/mod.rs | Added integration test for resize with runtime scales tensor |
| crates/burn-onnx/onnx-tests/build.rs | Added new test model to build configuration |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Codecov Report❌ Patch coverage is
❌ Your project check has failed because the head coverage (68.91%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #4369 +/- ##
==========================================
+ Coverage 68.65% 68.91% +0.25%
==========================================
Files 1411 1412 +1
Lines 168676 168506 -170
==========================================
+ Hits 115804 116122 +318
+ Misses 52872 52384 -488 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| } | ||
| } | ||
|
|
||
| #[allow(clippy::unnecessary_unwrap)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed on main, should remove the attribute
| #[allow(clippy::unnecessary_unwrap)] |
| pub fn forward(&self, input: Tensor<B, 4>, scale_factors: [i64; 1]) -> Tensor<B, 4> { | ||
| let output = { | ||
| let input_dims = input.dims(); | ||
| let target_height = ((input_dims[2] as f64) * (scale_factors[2] as f64)) | ||
| as usize; | ||
| let target_width = ((input_dims[3] as f64) * (scale_factors[3] as f64)) as usize; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing scale_factors[2] and scale_factors[3] is out of bounds for an array of length 1. Looks like the runtime shape generates an incorrect signature (an existing bug that also appears to exist for the runtime sizes, not just scales)
The ONNX
Resizeoperator accepts eithersizesorscalesas input (mutually exclusive per spec). The codegen only handled runtimesizesinput and panicked whenscaleswas provided: panicked at crates/burn-onnx/src/burn/node/resize.rs:175:22: Runtime resize requires sizes inputThis PR adds support for
ResizeScales::Runtimeby computing output dimensions frominput.dims() * scalesat runtimeChecklist
cargo run-checkscommand has been executed.Related Issues/PRs
Fixes #4336
Changes
ResizeScales::Runtimehandling inforward()for bothShapeandTensorinput typesTesting