
Conversation

@antimora
Collaborator

Summary

Fixes #4228 - ONNX Import fails with TypeMismatch { expected: "Tensor or Shape", actual: "Scalar(I64)" } in concat.rs

Changes

  • concat.rs extract_config(): added Scalar handling (rank 0)
  • concat.rs infer_types(): now handles mixed Scalar + Shape + rank-1 Tensor inputs
  • Output type is inferred as Shape when scalars are mixed with shapes, and as a 1D tensor otherwise
  • Added code generation for scalar inputs: each scalar is converted to a rank-1 tensor with Tensor::from_data_dtype() before concatenation (see the sketch below)
  • unsqueeze.rs extract_config(): added Scalar handling for the axes input (a single axis value given as a scalar instead of a 1D tensor)
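
For illustration, the code generated for a Concat over an i64 scalar and a rank-1 int tensor now looks roughly like the sketch below. Tensor::from_data_dtype is the conversion the PR uses; the variable names, device access, and the exact DType token shown here are illustrative stand-ins, not the real generated identifiers:

// Rough shape of the code burn-onnx now emits for Concat(scalar_i64, tensor_1d_i64) on axis 0.
// `scalar_input`, `tensor_input`, and `output` stand in for the actual argument names.
let scalar_as_tensor_0: Tensor<B, 1, Int> = Tensor::from_data_dtype(
    burn::tensor::TensorData::from([scalar_input]),
    &*self.device,
    burn::tensor::DType::I64, // emitted from the scalar's element type
);
let output = burn::tensor::Tensor::cat([scalar_as_tensor_0, tensor_input].into(), 0);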

Tests

Note

The original model from #4228 now progresses further but hits a separate pre-existing limitation in the Pad operator (only supports last-2-dimension padding). This is tracked in #4269 and is independent of this fix.

Adds handling for scalar inputs in Concat nodes, allowing concatenation of scalars and rank-1 tensors to produce a 1D tensor. Includes new ONNX test models and Rust tests to reproduce and verify the fix for issue tracel-ai#4228, ensuring correct type inference and code generation for mixed scalar/tensor inputs.
Adds support for Unsqueeze with scalar axes input, ensuring correct extraction and processing of scalar axis values. Updates Concat to handle mixing scalars, rank-1 tensors, and shapes, and adds a new test and ONNX model to verify Unsqueeze with scalar axes.
Included a reference to the relevant GitHub issue for tracking known limitations in the pad node implementation.
@antimora antimora requested review from Copilot and laggui January 22, 2026 21:32
@antimora antimora self-assigned this Jan 22, 2026
@antimora antimora added the onnx label Jan 22, 2026
@antimora antimora added the bug label Jan 22, 2026
Contributor

Copilot AI left a comment

Pull request overview

This PR fixes ONNX import failures when encountering scalar inputs in Concat and Unsqueeze operations. The issue (#4228) occurred when models used scalar indices from Gather operations with Concat, causing a TypeMismatch error.

Changes:

  • Added scalar input handling to Concat operator for both type inference and configuration extraction (the output-type rule is sketched after this list)
  • Added scalar axes handling to Unsqueeze operator's configuration extraction
  • Implemented code generation for converting scalars to rank-1 tensors before concatenation
  • Added comprehensive test coverage including Python test generators and Rust tests
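
As a standalone illustration of that inference rule, the snippet below captures the decision in miniature. The `Kind` enum and function are made up for the example; the real code works on onnx-ir's ArgType values:

// Sketch of the Concat output-kind rule: Shape wins whenever a Shape input is present,
// otherwise the result is a rank-1 tensor (including the all-scalar case).
#[derive(Clone, Copy, PartialEq, Debug)]
enum Kind {
    Scalar,
    Shape,
    Tensor1D,
}

fn concat_output_kind(inputs: &[Kind]) -> Kind {
    if inputs.iter().any(|k| *k == Kind::Shape) {
        Kind::Shape
    } else {
        Kind::Tensor1D
    }
}

fn main() {
    assert_eq!(concat_output_kind(&[Kind::Scalar, Kind::Shape]), Kind::Shape);
    assert_eq!(concat_output_kind(&[Kind::Scalar, Kind::Tensor1D]), Kind::Tensor1D);
    assert_eq!(concat_output_kind(&[Kind::Scalar, Kind::Scalar]), Kind::Tensor1D);
}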

Reviewed changes

Copilot reviewed 10 out of 13 changed files in this pull request and generated 1 comment.

  • crates/onnx-ir/src/node/concat.rs: Added scalar type handling in type inference and config extraction; includes logic to calculate output dimensions and types when scalars are mixed with tensors/shapes
  • crates/onnx-ir/src/node/unsqueeze.rs: Added scalar axes input handling with proper dtype validation and value extraction
  • crates/burn-onnx/src/burn/node/concat.rs: Implemented code generation to convert scalar inputs to rank-1 tensors using from_data_dtype before concatenation
  • crates/onnx-ir/src/node/pad.rs: Added tracking issue reference for existing pad operator limitation
  • crates/burn-onnx/onnx-tests/tests/concat/*.{py,onnx}: Added two new test models: concat_scalar_direct (direct scalar concat) and concat_scalar_from_gather (scalar with unsqueeze workaround)
  • crates/burn-onnx/onnx-tests/tests/unsqueeze/*.{py,onnx}: Added test model unsqueeze_scalar_axes for scalar axes input validation
  • crates/burn-onnx/onnx-tests/tests/{concat,unsqueeze}/mod.rs: Added Rust test implementations that verify correct output shapes and values (an illustrative test sketch follows this list)
  • crates/burn-onnx/onnx-tests/build.rs: Registered new test models in the build process
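
For context, the Rust tests in this suite typically build the generated model and assert on its output; a rough sketch is shown below. The model type, forward signature, backend alias, inputs, and expected values are illustrative guesses, not the actual generated test:

// Hypothetical test sketch: concatenate an i64 scalar with a rank-1 int tensor
// and check the resulting values. Assumes the usual onnx-tests imports
// (burn::tensor::{Int, Tensor, TensorData}) and a `TestBackend` alias.
#[test]
fn concat_scalar_with_1d_tensor() {
    let device = Default::default();
    // `concat_scalar_direct::Model` stands in for the generated model type.
    let model: concat_scalar_direct::Model<TestBackend> =
        concat_scalar_direct::Model::new(&device);

    let tensor = Tensor::<TestBackend, 1, Int>::from_ints([2, 3, 4], &device);
    let output = model.forward(1i64, tensor);

    let expected = TensorData::from([1i64, 2, 3, 4]);
    output.to_data().assert_eq(&expected, true);
}
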
Comments suppressed due to low confidence (1)

crates/onnx-ir/src/node/concat.rs:295

  • When concatenating scalar inputs (rank 0), the only valid axis is 0. However, if a user provides axis=-1, the normalization at line 295 will result in normalized_axis = -1 + 0 = -1, which is invalid. The TODO comment at line 297 mentions missing validation for normalized_axis being within valid range [0, rank).

For scalar concatenation specifically, the code should validate that the provided axis (after normalization) is 0, since that is the only valid axis for rank-0 inputs. Consider adding validation such as: if rank == 0 and normalized_axis != 0, return an error indicating that scalar concatenation only supports axis=0 (a minimal sketch follows the snippet below).

        // extract the rank based on input type
        let rank = match &node.inputs.first().unwrap().ty {
            ArgType::Tensor(tensor) => tensor.rank as i64,
            ArgType::Shape(_) => 1,  // Shapes are 1D
            ArgType::Scalar(_) => 0, // Scalars are rank-0
        };

        // if axis is negative, it is counted from the end
        let normalized_axis = if axis < 0 { axis + rank } else { axis };
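
A minimal sketch of that check, assuming a panic-style error (the real extract_config may report errors differently):

        // Sketch only: scalar (rank-0) inputs can only be concatenated along axis 0.
        if rank == 0 && normalized_axis != 0 {
            panic!(
                "Concat: scalar inputs only support axis=0, got axis={}",
                axis
            );
        }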


Comment on lines +21 to +58
ArgType::Tensor(_) if has_scalar => {
    // Mixed scalar/tensor concatenation - convert scalars to rank-1 tensors first
    let mut inits = Vec::new();
    let mut input_exprs = Vec::new();

    for (i, input_arg) in self.inputs.iter().enumerate() {
        let input = scope.arg(input_arg);

        if input_arg.ty.is_scalar() {
            // Convert scalar to rank-1 tensor
            let dtype = input_arg.ty.elem_type();
            let dtype_tokens = dtype.to_tokens();
            let kind = match dtype {
                DType::Bool => quote! { , Bool },
                _ if dtype.is_float() => quote! {},
                _ => quote! { , Int },
            };
            let temp_name =
                Ident::new(&format!("scalar_as_tensor_{}", i), Span::call_site());
            let init = quote! {
                let #temp_name: Tensor<B, 1 #kind> = Tensor::from_data_dtype(
                    burn::tensor::TensorData::from([#input]),
                    &*self.device,
                    #dtype_tokens
                );
            };
            inits.push(init);
            input_exprs.push(quote! { #temp_name });
        } else {
            input_exprs.push(input);
        }
    }

    quote! {
        #(#inits)*
        let #output = burn::tensor::Tensor::cat([#(#input_exprs),*].into(), #dim);
    }
}

Copilot AI Jan 22, 2026

The scalar-to-tensor conversion logic only handles the case where the output is ArgType::Tensor. However, when scalars are mixed with ArgType::Shape inputs, the output type is set to ArgType::Shape (per line 148 in onnx-ir/src/node/concat.rs). In this case, the code will fall through to the Shape branch (not shown in this diff but at lines 67-87), which will try to use &#input_name[..] on scalar values. This won't compile since scalars are not sliceable arrays.

The fix should also handle scalar inputs when the output type is ArgType::Shape, by converting scalars to single-element arrays before concatenation, similar to how it's done here for tensors (see the rough sketch below).
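
A rough sketch of what the generated Shape-branch code could look like once scalars are wrapped. This assumes shapes are lowered to i64 arrays/slices, which may not match burn-onnx exactly; the names are illustrative:

// Illustrative only: wrapping the scalar in a one-element array keeps the
// existing slice-based Shape concatenation valid.
let scalar_as_shape_1: [i64; 1] = [scalar_input];
let output: Vec<i64> = [&shape_input[..], &scalar_as_shape_1[..]].concat();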

Member

Looks like the bot has effectively caught a bug 😄

Since this PR is meant to handle scalar inputs, mixed scalar + shape concat should be fixed before we merge?

@codecov

codecov bot commented Jan 23, 2026

Codecov Report

❌ Patch coverage is 77.77778% with 34 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.87%. Comparing base (3cd0671) to head (23e0edf).

Files with missing lines (patch coverage / lines missing):
  • crates/onnx-ir/src/node/concat.rs: 74.15%, 23 missing ⚠️
  • crates/onnx-ir/src/node/unsqueeze.rs: 50.00%, 7 missing ⚠️
  • crates/burn-onnx/src/burn/node/concat.rs: 81.81%, 4 missing ⚠️

❌ Your patch check has failed because the patch coverage (77.77%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (68.87%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@           Coverage Diff            @@
##             main    #4370    +/-   ##
========================================
  Coverage   68.86%   68.87%            
========================================
  Files        1412     1412            
  Lines      168245   168392   +147     
========================================
+ Hits       115857   115974   +117     
- Misses      52388    52418    +30     



Development

Successfully merging this pull request may close these issues.

ONNX Import Fails with TypeMismatch Error in concat.rs (v0.20.0-pre.6)
