Fix ONNX Concat and Unsqueeze to handle Scalar inputs #4370
Conversation
Adds handling for scalar inputs in Concat nodes, allowing concatenation of scalars and rank-1 tensors to produce a 1D tensor. Includes new ONNX test models and Rust tests to reproduce and verify the fix for issue tracel-ai#4228, ensuring correct type inference and code generation for mixed scalar/tensor inputs.
Adds support for Unsqueeze with scalar axes input, ensuring correct extraction and processing of scalar axis values. Updates Concat to handle mixing scalars, rank-1 tensors, and shapes, and adds a new test and ONNX model to verify Unsqueeze with scalar axes.
Included a reference to the relevant GitHub issue for tracking known limitations in the pad node implementation.
Pull request overview
This PR fixes ONNX import failures when encountering scalar inputs in Concat and Unsqueeze operations. The issue (#4228) occurred when models used scalar indices from Gather operations with Concat, causing a TypeMismatch error.
Changes:
- Added scalar input handling to Concat operator for both type inference and configuration extraction
- Added scalar axes handling to Unsqueeze operator's configuration extraction
- Implemented code generation for converting scalars to rank-1 tensors before concatenation
- Added comprehensive test coverage including Python test generators and Rust tests
Reviewed changes
Copilot reviewed 10 out of 13 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| crates/onnx-ir/src/node/concat.rs | Added scalar type handling in type inference and config extraction; includes logic to calculate output dimensions and types when scalars are mixed with tensors/shapes |
| crates/onnx-ir/src/node/unsqueeze.rs | Added scalar axes input handling with proper dtype validation and value extraction |
| crates/burn-onnx/src/burn/node/concat.rs | Implemented code generation to convert scalar inputs to rank-1 tensors using from_data_dtype before concatenation |
| crates/onnx-ir/src/node/pad.rs | Added tracking issue reference for existing pad operator limitation |
| crates/burn-onnx/onnx-tests/tests/concat/*.{py,onnx} | Added two new test models: concat_scalar_direct (direct scalar concat) and concat_scalar_from_gather (scalar with unsqueeze workaround) |
| crates/burn-onnx/onnx-tests/tests/unsqueeze/*.{py,onnx} | Added test model unsqueeze_scalar_axes for scalar axes input validation |
| crates/burn-onnx/onnx-tests/tests/{concat,unsqueeze}/mod.rs | Added Rust test implementations that verify correct output shapes and values |
| crates/burn-onnx/onnx-tests/build.rs | Registered new test models in the build process |
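The output-type rule described for `crates/onnx-ir/src/node/concat.rs` above (scalars mixed with shapes produce a `Shape` output, scalars mixed with rank-1 tensors produce a 1D tensor) can be sketched in isolation. This is a simplified stand-in, not the crate's actual `ArgType` API; `Ty` and `concat_output_ty` are illustrative names:

```rust
// Simplified sketch (assumed names, not onnx-ir's real ArgType) of the
// Concat output-type rule: any Shape input makes the output a Shape;
// otherwise mixed scalar/tensor inputs produce a rank-1 tensor.
#[derive(Debug, PartialEq)]
enum Ty {
    Scalar,
    Shape,
    Tensor1D,
}

fn concat_output_ty(inputs: &[Ty]) -> Ty {
    if inputs.iter().any(|t| *t == Ty::Shape) {
        Ty::Shape // mixing scalars with shapes yields a Shape
    } else {
        Ty::Tensor1D // scalars and rank-1 tensors concatenate into a 1D tensor
    }
}

fn main() {
    assert_eq!(concat_output_ty(&[Ty::Scalar, Ty::Shape]), Ty::Shape);
    assert_eq!(concat_output_ty(&[Ty::Scalar, Ty::Tensor1D]), Ty::Tensor1D);
    println!("ok");
}
```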
Comments suppressed due to low confidence (1)
crates/onnx-ir/src/node/concat.rs:295
- When concatenating scalar inputs (rank 0), the only valid axis is 0. However, if a user provides axis=-1, the normalization at line 295 will result in normalized_axis = -1 + 0 = -1, which is invalid. The TODO comment at line 297 mentions missing validation for normalized_axis being within valid range [0, rank).
For scalar concatenation specifically, the code should validate that the provided axis (after normalization) is 0, since that's the only valid axis for rank-0 inputs. Consider adding validation such as:
if rank == 0 and normalized_axis != 0, return an error indicating that scalar concatenation only supports axis=0.
```rust
// extract the rank based on input type
let rank = match &node.inputs.first().unwrap().ty {
    ArgType::Tensor(tensor) => tensor.rank as i64,
    ArgType::Shape(_) => 1,  // Shapes are 1D
    ArgType::Scalar(_) => 0, // Scalars are rank-0
};
// if axis is negative, it is counted from the end
let normalized_axis = if axis < 0 { axis + rank } else { axis };
```
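The suggested check can be sketched as a stand-alone function; the name `normalize_concat_axis` and the `String` error type are assumptions for illustration, not the PR's actual code:

```rust
// Sketch of the suggested axis validation: normalize a possibly-negative
// axis, then require axis = 0 for rank-0 (scalar) inputs and
// 0 <= axis < rank otherwise. Names and error type are illustrative.
fn normalize_concat_axis(axis: i64, rank: i64) -> Result<i64, String> {
    // if axis is negative, it is counted from the end
    let normalized = if axis < 0 { axis + rank } else { axis };
    if rank == 0 {
        // rank-0 inputs concatenate into a 1D tensor, so axis 0 is the
        // only meaningful value; axis = -1 normalizes to -1 and is rejected
        if normalized != 0 {
            return Err(format!(
                "scalar concatenation only supports axis=0, got {axis}"
            ));
        }
    } else if normalized < 0 || normalized >= rank {
        return Err(format!("axis {axis} out of range for rank {rank}"));
    }
    Ok(normalized)
}

fn main() {
    assert_eq!(normalize_concat_axis(-1, 3), Ok(2));
    assert_eq!(normalize_concat_axis(0, 0), Ok(0));
    assert!(normalize_concat_axis(-1, 0).is_err());
    println!("ok");
}
```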
```rust
ArgType::Tensor(_) if has_scalar => {
    // Mixed scalar/tensor concatenation - convert scalars to rank-1 tensors first
    let mut inits = Vec::new();
    let mut input_exprs = Vec::new();

    for (i, input_arg) in self.inputs.iter().enumerate() {
        let input = scope.arg(input_arg);

        if input_arg.ty.is_scalar() {
            // Convert scalar to rank-1 tensor
            let dtype = input_arg.ty.elem_type();
            let dtype_tokens = dtype.to_tokens();
            let kind = match dtype {
                DType::Bool => quote! { , Bool },
                _ if dtype.is_float() => quote! {},
                _ => quote! { , Int },
            };
            let temp_name =
                Ident::new(&format!("scalar_as_tensor_{}", i), Span::call_site());
            let init = quote! {
                let #temp_name: Tensor<B, 1 #kind> = Tensor::from_data_dtype(
                    burn::tensor::TensorData::from([#input]),
                    &*self.device,
                    #dtype_tokens
                );
            };
            inits.push(init);
            input_exprs.push(quote! { #temp_name });
        } else {
            input_exprs.push(input);
        }
    }

    quote! {
        #(#inits)*
        let #output = burn::tensor::Tensor::cat([#(#input_exprs),*].into(), #dim);
    }
}
```
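The `kind` match in the diff above selects burn's tensor kind marker (`Bool`, default float, or `Int`) from the element dtype. The rule in isolation, using stand-in enums rather than the crate's actual `DType`, looks like:

```rust
// Simplified sketch of the dtype -> tensor-kind mapping used by the codegen
// above. Dtype and Kind are stand-in enums, not burn's actual types.
#[derive(Debug, PartialEq)]
enum Dtype {
    Bool,
    F32,
    F64,
    I32,
    I64,
}

#[derive(Debug, PartialEq)]
enum Kind {
    Bool,  // emits `, Bool`  -> Tensor<B, 1, Bool>
    Float, // emits nothing   -> Tensor<B, 1>
    Int,   // emits `, Int`   -> Tensor<B, 1, Int>
}

fn tensor_kind(dtype: &Dtype) -> Kind {
    match dtype {
        Dtype::Bool => Kind::Bool,
        Dtype::F32 | Dtype::F64 => Kind::Float,
        _ => Kind::Int, // everything else is treated as an integer kind
    }
}

fn main() {
    assert_eq!(tensor_kind(&Dtype::Bool), Kind::Bool);
    assert_eq!(tensor_kind(&Dtype::F32), Kind::Float);
    assert_eq!(tensor_kind(&Dtype::I64), Kind::Int);
    println!("ok");
}
```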
Copilot (AI), Jan 22, 2026
The scalar-to-tensor conversion logic only handles the case where the output is ArgType::Tensor. However, when scalars are mixed with ArgType::Shape inputs, the output type is set to ArgType::Shape (per line 148 in onnx-ir/src/node/concat.rs). In this case, the code will fall through to the Shape branch (not shown in this diff but at lines 67-87), which will try to use &#input_name[..] on scalar values. This won't compile since scalars are not sliceable arrays.
The fix should also handle scalar inputs when the output type is ArgType::Shape, by converting scalars to single-element arrays before concatenation, similar to how it's done here for tensors.
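A minimal sketch of that suggestion, assuming the generated Shape-concat code works over `i64` slices; the `ShapePart` enum and `concat_shape` helper here are hypothetical illustrations, not the actual codegen output:

```rust
// Hypothetical sketch: when the Concat output is a Shape, scalar inputs can
// be promoted to single-element arrays so every part is sliceable, mirroring
// how scalars are promoted to rank-1 tensors in the tensor branch.
enum ShapePart {
    Shape(Vec<i64>), // an ArgType::Shape input: already a 1D list of dims
    Scalar(i64),     // an ArgType::Scalar input: promoted to [value]
}

fn concat_shape(parts: &[ShapePart]) -> Vec<i64> {
    let mut out = Vec::new();
    for p in parts {
        match p {
            ShapePart::Shape(dims) => out.extend_from_slice(&dims[..]),
            // a bare scalar is not sliceable; wrap it first
            ShapePart::Scalar(v) => out.extend_from_slice(&[*v][..]),
        }
    }
    out
}

fn main() {
    let parts = [ShapePart::Shape(vec![2, 3]), ShapePart::Scalar(4)];
    assert_eq!(concat_shape(&parts), vec![2, 3, 4]);
    println!("{:?}", concat_shape(&parts));
}
```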
Looks like the bot has effectively caught a bug 😄
Since this PR is meant to handle scalar inputs, mixed scalar + shape concat should be fixed before we merge?
Codecov Report

❌ Your patch check has failed because the patch coverage (77.77%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

```
@@ Coverage Diff @@
##             main    #4370     +/-   ##
========================================
  Coverage   68.86%   68.87%
========================================
  Files        1412     1412
  Lines      168245   168392    +147
========================================
+ Hits       115857   115974    +117
- Misses      52388    52418     +30
```

View full report in Codecov by Sentry.
Summary
Fixes #4228 - ONNX import fails with `TypeMismatch { expected: "Tensor or Shape", actual: "Scalar(I64)" }` in `concat.rs`.

Changes
- Concat `extract_config()`: Added `Scalar` handling (rank 0)
- Concat `infer_types()`: Now handles mixed `Scalar` + `Shape` + rank-1 `Tensor` inputs; the output is a `Shape` when mixing scalars with shapes, or a 1D tensor otherwise
- Codegen: scalars are converted with `Tensor::from_data_dtype()` before concatenation
- Unsqueeze `extract_config()`: Added `Scalar` handling for the axes input (a single axis value as a scalar instead of a 1D tensor)

Tests
- `concat_scalar_direct` - minimal reproduction of "ONNX Import Fails with TypeMismatch Error in `concat.rs` (v0.20.0-pre.6)" #4228
- `concat_scalar_from_gather` - pattern with Unsqueeze workaround
- `unsqueeze_scalar_axes` - tests scalar axes input handling

Note
The original model from #4228 now progresses further but hits a separate pre-existing limitation in the `Pad` operator (it only supports padding over the last two dimensions). This is tracked in #4269 and is independent of this fix.