
Conversation

@GregorySech commented Dec 19, 2025

This pull request fixes a panic that occurred when the ONNX Expand operation received a scalar input instead of a tensor. The fix handles scalar inputs properly by converting them to rank-1 tensors during code generation.

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Fixes #4206

Changes

  • Modified type inference in ExpandProcessor::infer_types to accept scalar inputs and propagate their data types
  • Updated code generation in ExpandNode::forward to convert scalar inputs to rank-1 tensors with appropriate tensor kinds (Float, Int, Bool); a minimal sketch follows this list
  • Added comprehensive test coverage including unit tests and an end-to-end ONNX model test
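For orientation, here is a minimal sketch of the kind of conversion the generated code performs, with hypothetical names (`expand_scalar_example`, the `[2, 3]` target shape); the actual generated code lives in `ExpandNode::forward`:

```rust
use burn::tensor::{backend::Backend, Tensor};

// Sketch only: wrap the scalar in a rank-1 tensor so Expand has a tensor
// to broadcast, then expand it to the target shape as ONNX Expand does.
fn expand_scalar_example<B: Backend>(scalar: f32, device: &B::Device) -> Tensor<B, 2> {
    let input = Tensor::<B, 1>::from_data([scalar], device);
    input.expand([2, 3])
}
```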

Testing

  1. burn-import/onnx-tests/tests/expand/expand_scalar.py generates expand_scalar.onnx, an ONNX graph modeled after the expand_tensor test but with a scalar as input.
  2. The expand_scalar test checks that the output of the model generated from expand_scalar.onnx has the correct shape (a rough sketch follows this list).
  3. Code generation unit tests.
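For item 2, a rough sketch of the shape check, with hypothetical module and backend names standing in for the generated ones:

```rust
// Hypothetical test: `Model` is the type generated from expand_scalar.onnx
// at build time, and `TestBackend` the backend alias used by the test suite.
#[test]
fn expand_scalar() {
    let device = Default::default();
    let model = expand_scalar::Model::<TestBackend>::new(&device);
    // Feed a scalar and verify the expanded output has the target shape.
    let output = model.forward(1.5f32);
    assert_eq!(output.shape().dims, [2, 3]);
}
```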

@GregorySech (Author)

I created this pull request to get a bit of feedback. I think I should extract the logic that creates the temporary `Tensor` for the scalar.
I'm also unsure whether it's okay to just use `burn::tensor::Int` for all those different specialisations of `int`. The same goes for floating-point types.

@laggui (Member) commented Dec 22, 2025

CC'ing @antimora since you self-assigned the linked issue.

@antimora (Collaborator)

> I created this pull request to get a bit of feedback. I think I should extract the logic that creates the temporary `Tensor` for the scalar. I'm also unsure whether it's okay to just use `burn::tensor::Int` for all those different specialisations of `int`. The same goes for floating-point types.

I think for this one it's better if I make the changes on your PR directly, because there are a few things to consider. This is related to tensor creation with a dtype.

@GregorySech (Author)

> > I created this pull request to get a bit of feedback. I think I should extract the logic that creates the temporary `Tensor` for the scalar. I'm also unsure whether it's okay to just use `burn::tensor::Int` for all those different specialisations of `int`. The same goes for floating-point types.
>
> I think for this one it's better if I make the changes on your PR directly, because there are a few things to consider. This is related to tensor creation with a dtype.

@antimora let me know if you have any directions I should implement.
I guess that creating a tensor from a scalar with a given dtype might be useful for other ops as well.

@antimora (Collaborator) commented Jan 6, 2026

@GregorySech sorry, I thought I could handle it before my travels (I'll be back on the 20th).

Recently I discovered that tensor creation from a native value was defaulting to the backend's default dtype. You need to make sure you pass the scalar's dtype when creating a tensor from a native scalar. You can find some existing examples. @laggui is also aware of this issue.
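For illustration, a hedged sketch of the pattern described here, using the `from_data_dtype` API that this PR ends up calling, so the scalar's original dtype is carried through instead of the backend default:

```rust
use burn::tensor::{backend::Backend, DType, Tensor, TensorData};

// Sketch under the assumptions above: `value` holds the ONNX scalar and
// `dtype` its original element type. Plain `from_data` would convert the
// data to the backend's default dtype; `from_data_dtype` preserves it.
fn scalar_to_rank1<B: Backend>(value: f64, dtype: DType, device: &B::Device) -> Tensor<B, 1> {
    Tensor::<B, 1>::from_data_dtype(TensorData::from([value]), device, dtype)
}
```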

Please proceed with this PR and assume I am not working on it. I will have access to my computer after Jan 20th.

Thank you for your contribution.

antimora self-requested a review January 20, 2026 20:33
antimora added the onnx label Jan 21, 2026
antimora changed the title from "WIP: Fix panic for Expand onnx operation with scalar input" to "Fix panic for Expand onnx operation with scalar input" Jan 21, 2026
antimora requested a review from Copilot January 21, 2026 19:46
antimora added the bug (Something isn't working) label Jan 21, 2026
Copilot AI (Contributor) left a comment


Pull request overview

This pull request fixes a panic that occurred when the ONNX Expand operation received a scalar input instead of a tensor. The fix handles scalar inputs properly by converting them to rank-1 tensors during code generation.

Changes:

  • Modified type inference in ExpandProcessor::infer_types to accept scalar inputs and propagate their data types
  • Updated code generation in ExpandNode::forward to convert scalar inputs to rank-1 tensors with appropriate tensor kinds (Float, Int, Bool)
  • Added comprehensive test coverage including unit tests and an end-to-end ONNX model test

Reviewed changes

Copilot reviewed 5 out of 6 changed files in this pull request and generated 1 comment.

Summary per file:

  • crates/onnx-ir/src/node/expand.rs — Added scalar type handling in type inference, extracting dtype from scalar inputs
  • crates/burn-onnx/src/burn/node/expand.rs — Implemented scalar-to-tensor conversion logic for all supported data types (float, int, uint, bool) with comprehensive unit tests
  • crates/burn-onnx/onnx-tests/tests/expand/mod.rs — Added integration test for scalar expand operation verifying both shape and values
  • crates/burn-onnx/onnx-tests/tests/expand/expand_scalar.py — Python script to generate the test ONNX model with scalar input
  • crates/burn-onnx/onnx-tests/tests/expand/expand_scalar.onnx — Generated ONNX model file
  • crates/burn-onnx/onnx-tests/build.rs — Added the new test model to the build configuration


- Add scalar input handling in ExpandNode codegen by converting scalar to rank-1 tensor using `from_data_dtype` with explicit dtype preservation
- Update type inference in onnx-ir to accept Scalar inputs
- Add expand_scalar integration test with value verification

Co-Authored-By: GregorySech <16958043+GregorySech@users.noreply.github.com>
@antimora (Collaborator) left a comment


Made the fixes directly, so personally approving. Can be merged after @laggui has a chance to review.

Though this fixes the scalar expand issue, there are more issues when converting the decoder ONNX model:

DEBUG onnx_ir::phases::node_conversion: Preserving outer-scope value_store for 'constant459_out1' (Constant)
DEBUG onnx_ir::pipeline:  PHASE 3: Type Inference
ERROR burn_onnx::logger: PANIC => panicked at crates/burn-onnx/src/model_gen.rs:296:33:
Failed to parse ONNX file './out/edge_sam_3x_decoder_opset16.onnx': Type inference failed: Custom("Cannot determine output rank for Expand node expand4 with fully dynamic shape tensor")

thread 'main' (484957) panicked at crates/burn-onnx/src/model_gen.rs:296:33:
Failed to parse ONNX file './out/edge_sam_3x_decoder_opset16.onnx': Type inference failed: Custom("Cannot determine output rank for Expand node expand4 with fully dynamic shape tensor")

I would encourage opening a new issue or PR for this.

@codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 99.27007% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 68.67%. Comparing base (5843c6e) to head (d911178).

Files with missing lines | Patch % | Lines
crates/onnx-ir/src/node/expand.rs | 50.00% | 1 Missing ⚠️

❌ Your project check has failed because the head coverage (68.67%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4230      +/-   ##
==========================================
+ Coverage   68.65%   68.67%   +0.02%     
==========================================
  Files        1411     1411              
  Lines      168676   168805     +129     
==========================================
+ Hits       115804   115933     +129     
  Misses      52872    52872              


@GregorySech (Author)

Hi, please wait before merging; I need to push a fix to handle the corner case where ONNX cannot infer the rank of the output but the input is a scalar. Currently it's still panicking, but it can be handled.

antimora requested a review from laggui January 22, 2026 17:06
@laggui (Member) left a comment


> Hi, please wait before merging; I need to push a fix to handle the corner case where ONNX cannot infer the rank of the output but the input is a scalar. Currently it's still panicking, but it can be handled.

Sounds good! Will request changes just to block any accidental merge 😄

Otherwise the current changes LGTM

Comment on lines +29 to +35
let init = quote! {
    let input = Tensor::<B, 1 #kind>::from_data_dtype(
        burn::tensor::TensorData::from([#input]),
        &*self.device,
        #dtype_tokens
    );
};
Member

Why not use `Tensor::<B, 1, #kind>::full` instead? Just a small suggestion.
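For comparison, a hedged sketch of what that suggestion could look like in the codegen, reusing the PR's placeholder tokens (`#kind`, `#input`); note that `full` is provided for numeric tensor kinds, so the Bool case may still need the `from_data_dtype` path shown above:

```rust
use proc_macro2::TokenStream;
use quote::quote;

// Hypothetical helper: emit code that creates a [1]-shaped tensor filled
// with the scalar value instead of going through TensorData.
fn scalar_init(kind: &TokenStream, input: &TokenStream) -> TokenStream {
    quote! {
        let input = Tensor::<B, 1 #kind>::full([1], #input, &*self.device);
    }
}
```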


Labels

bug (Something isn't working), onnx


Development

Successfully merging this pull request may close these issues.

Expand of a scalar panics (burn-import)

3 participants