Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/burn-onnx/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ fn main() {
.input("tests/concat/concat_mixed_three_elements.onnx")
.input("tests/concat/concat_multiple_mixed.onnx")
.input("tests/concat/concat_with_constants.onnx")
.input("tests/concat/concat_scalar_direct.onnx")
.input("tests/concat/concat_scalar_from_gather.onnx")
.input("tests/constant/constant_f32.onnx")
.input("tests/constant/constant_f64.onnx")
.input("tests/constant/constant_i32.onnx")
Expand Down Expand Up @@ -351,6 +353,7 @@ fn main() {
.input("tests/unsqueeze/unsqueeze_runtime_axes.onnx")
.input("tests/unsqueeze/unsqueeze_like.onnx")
.input("tests/unsqueeze/unsqueeze_int_to_shape.onnx")
.input("tests/unsqueeze/unsqueeze_scalar_axes.onnx")
.input("tests/unsqueeze/squeeze_unsqueeze_roundtrip.onnx")
.input("tests/split/split.onnx")
.input("tests/xor/xor.onnx")
Expand Down
Binary file not shown.
138 changes: 138 additions & 0 deletions crates/burn-onnx/onnx-tests/tests/concat/concat_scalar_direct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
#!/usr/bin/env python3

# used to generate model: concat_scalar_direct.onnx
# This test directly reproduces issue #4228: Concat fails when receiving Scalar(I64) input.
#
# Pattern: Shape -> Gather (scalar index) -> Concat (direct, no Unsqueeze)
# This is the exact pattern that causes the TypeMismatch error in onnx-ir.

import onnx
import onnx.helper
import numpy as np


def build_model():
    """Build the ONNX model in which Concat consumes two rank-0 INT64 values directly.

    Graph: Shape(input1) -> Gather(scalar index 0) -> Concat(axis=0), where the
    second Concat input is a scalar constant (64). No Unsqueeze is placed in
    front of Concat; that is the pattern reported in issue #4228.

    Returns:
        The model created by ``onnx.helper.make_model`` (IR version 8, opset 16).
    """
    helper = onnx.helper
    int64 = onnx.TensorProto.INT64

    # Both constants are rank-0: `dims=[]` is what makes them scalars.
    index_tensor = helper.make_tensor(
        name="idx_value",
        data_type=int64,
        dims=[],
        vals=[0],
    )
    extra_dim_tensor = helper.make_tensor(
        name="other_dim_value",
        data_type=int64,
        dims=[],
        vals=[64],
    )

    nodes = [
        # Shape of the input tensor: [batch, channels, height, width].
        helper.make_node(
            "Shape", inputs=["input1"], outputs=["shape1"], name="/Shape"
        ),
        # Scalar index 0, used to pick the batch dimension out of the shape.
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=["idx"],
            value=index_tensor,
            name="/ConstIdx",
        ),
        # A scalar index yields a scalar result (Scalar(I64) in onnx-ir).
        helper.make_node(
            "Gather",
            inputs=["shape1", "idx"],
            outputs=["batch_dim"],
            axis=0,
            name="/Gather",
        ),
        # Second scalar operand for Concat.
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=["other_dim"],
            value=extra_dim_tensor,
            name="/ConstDim",
        ),
        # Scalars go straight into Concat; this is what triggers the bug.
        helper.make_node(
            "Concat",
            inputs=["batch_dim", "other_dim"],
            outputs=["output_shape"],
            axis=0,
            name="/Concat",
        ),
    ]

    graph_input = helper.make_value_info(
        name="input1",
        type_proto=helper.make_tensor_type_proto(
            elem_type=onnx.TensorProto.FLOAT, shape=[2, 3, 4, 5]
        ),
    )
    # Declared graph output: 1D INT64 tensor of length 2, i.e. [batch_dim, 64].
    graph_output = helper.make_value_info(
        name="output_shape",
        type_proto=helper.make_tensor_type_proto(elem_type=int64, shape=[2]),
    )

    graph = helper.make_graph(
        name="main_graph",
        nodes=nodes,
        inputs=[graph_input],
        outputs=[graph_output],
    )

    return helper.make_model(
        graph,
        ir_version=8,
        opset_imports=[helper.make_operatorsetid("", 16)],
    )


def main():
    """Export concat_scalar_direct.onnx, validate it, and try a reference run.

    The model is written to the current working directory and then checked
    with ``onnx.checker``. If ``onnx.reference`` can be imported, the model is
    also executed once and the result is printed next to the expected values.
    """
    model = build_model()
    file_name = "concat_scalar_direct.onnx"
    onnx.save(model, file_name)
    onnx.checker.check_model(file_name)

    print(f"Finished exporting model to {file_name}")

    # Optional smoke run; only an ImportError is treated as "skip".
    try:
        from onnx.reference import ReferenceEvaluator

        # All-ones input with shape [2, 3, 4, 5]
        test_input = np.ones(shape=(2, 3, 4, 5), dtype=np.float32)

        feeds = {"input1": test_input}
        evaluator = ReferenceEvaluator(model)
        result = evaluator.run(None, feeds)

        print(f"Test input shape: {test_input.shape}")
        print(f"Output shape tensor: {result[0]}")
        print(f"Expected: [2, 64] (batch=2 from input shape, 64 from constant)")

    except ImportError:
        print("onnx.reference not available, skipping inference test")


if __name__ == "__main__":
    main()
Binary file not shown.
161 changes: 161 additions & 0 deletions crates/burn-onnx/onnx-tests/tests/concat/concat_scalar_from_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

# used to generate model: concat_scalar_from_gather.onnx
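# This test covers the workaround pattern for issue #4228 (Concat fails when receiving
# a Scalar(I64) input from a Gather operation with scalar index): the scalar is
# unsqueezed to a 1D tensor before it reaches Concat.
#
# Pattern: Shape -> Gather (scalar index) -> Unsqueeze -> Concat
# The Gather with scalar index produces a scalar output, which Unsqueeze turns into a
# 1D tensor so that Concat only receives 1D inputs.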

import onnx
import onnx.helper
import numpy as np


def build_model():
    """Build the ONNX model: Shape -> Gather (scalar index) -> Unsqueeze -> Concat.

    The batch dimension is gathered from the input shape with a rank-0 index,
    unsqueezed on axis 0 into a 1D tensor, and concatenated with the constant
    dimensions [32, 64].

    Returns:
        The model created by ``onnx.helper.make_model`` (IR version 8, opset 16).
    """
    helper = onnx.helper
    int64 = onnx.TensorProto.INT64

    # Rank-0 index (`dims=[]`), so Gather emits a scalar.
    index_tensor = helper.make_tensor(
        name="idx_value",
        data_type=int64,
        dims=[],
        vals=[0],
    )
    # 1D constant holding the trailing dimensions.
    new_dims_tensor = helper.make_tensor(
        name="new_dims_value",
        data_type=int64,
        dims=[2],
        vals=[32, 64],
    )
    # Axes input for Unsqueeze: insert a new axis at position 0.
    axes_tensor = helper.make_tensor(
        name="unsqueeze_axes_value",
        data_type=int64,
        dims=[1],
        vals=[0],
    )

    nodes = [
        # Shape of the input tensor: [batch, channels, height, width].
        helper.make_node(
            "Shape", inputs=["input1"], outputs=["shape1"], name="/Shape"
        ),
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=["idx"],
            value=index_tensor,
            name="/ConstIdx",
        ),
        # Scalar index in, scalar batch dimension out (Scalar(I64)).
        helper.make_node(
            "Gather",
            inputs=["shape1", "idx"],
            outputs=["batch_dim"],
            axis=0,
            name="/Gather",
        ),
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=["new_dims"],
            value=new_dims_tensor,
            name="/ConstDims",
        ),
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=["unsqueeze_axes"],
            value=axes_tensor,
            name="/UnsqueezeAxes",
        ),
        # Lift the scalar to 1D so it can be concatenated.
        helper.make_node(
            "Unsqueeze",
            inputs=["batch_dim", "unsqueeze_axes"],
            outputs=["batch_dim_1d"],
            name="/Unsqueeze",
        ),
        # Join the unsqueezed batch dimension with the constant dimensions.
        helper.make_node(
            "Concat",
            inputs=["batch_dim_1d", "new_dims"],
            outputs=["output_shape"],
            axis=0,
            name="/Concat",
        ),
    ]

    graph_input = helper.make_value_info(
        name="input1",
        type_proto=helper.make_tensor_type_proto(
            elem_type=onnx.TensorProto.FLOAT, shape=[2, 3, 4, 5]
        ),
    )
    # Declared graph output: 1D INT64 tensor of length 3, i.e. [batch, 32, 64].
    graph_output = helper.make_value_info(
        name="output_shape",
        type_proto=helper.make_tensor_type_proto(elem_type=int64, shape=[3]),
    )

    graph = helper.make_graph(
        name="main_graph",
        nodes=nodes,
        inputs=[graph_input],
        outputs=[graph_output],
    )

    return helper.make_model(
        graph,
        ir_version=8,
        opset_imports=[helper.make_operatorsetid("", 16)],
    )


def main():
    """Export concat_scalar_from_gather.onnx, validate it, and try a reference run.

    The model is written to the current working directory and then checked
    with ``onnx.checker``. If ``onnx.reference`` can be imported, the model is
    also executed once and the result is printed next to the expected values.
    """
    model = build_model()
    file_name = "concat_scalar_from_gather.onnx"
    onnx.save(model, file_name)
    onnx.checker.check_model(file_name)

    print(f"Finished exporting model to {file_name}")

    # Optional smoke run; only an ImportError is treated as "skip".
    try:
        from onnx.reference import ReferenceEvaluator

        # All-ones input with shape [2, 3, 4, 5]
        test_input = np.ones(shape=(2, 3, 4, 5), dtype=np.float32)

        feeds = {"input1": test_input}
        evaluator = ReferenceEvaluator(model)
        result = evaluator.run(None, feeds)

        print(f"Test input shape: {test_input.shape}")
        print(f"Output shape tensor: {result[0]}")
        print(f"Expected: [2, 32, 64] (batch=2 from input, then 32, 64 from constant)")

    except ImportError:
        print("onnx.reference not available, skipping inference test")


if __name__ == "__main__":
    main()
42 changes: 41 additions & 1 deletion crates/burn-onnx/onnx-tests/tests/concat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ include_models!(
concat_mixed_single_element,
concat_mixed_three_elements,
concat_multiple_mixed,
concat_with_constants
concat_with_constants,
concat_scalar_direct,
concat_scalar_from_gather
);

#[cfg(test)]
Expand Down Expand Up @@ -144,4 +146,42 @@ mod tests {
let expected: [i64; 8] = [3, 4, 2, 3, 5, 7, 8, 9];
assert_eq!(output, expected);
}

#[test]
fn concat_scalar_direct() {
    // Regression test for issue #4228: Concat fed with Scalar(I64) inputs.
    // Graph pattern: Shape -> Gather (scalar index) -> Concat
    let device = Default::default();
    let model = concat_scalar_direct::Model::<TestBackend>::new(&device);

    // Input of shape [2, 3, 4, 5]; only its shape is read by the graph.
    let input = Tensor::<TestBackend, 4>::zeros([2, 3, 4, 5], &device);

    // Gather picks the batch dimension, Concat appends the constant.
    let output = model.forward(input);

    // Expect [2, 64]: batch=2 from the input shape, 64 from the constant.
    let expected = Tensor::<TestBackend, 1, burn::prelude::Int>::from_ints([2, 64], &device);
    let all_equal = output.equal(expected).all().into_scalar();
    assert!(all_equal);
}

#[test]
fn concat_scalar_from_gather() {
    // Workaround pattern: Shape -> Gather -> Unsqueeze -> Concat.
    // Unsqueeze brings the scalar back into shape context, so the model
    // output is compared against a plain shape array.
    let device = Default::default();
    let model = concat_scalar_from_gather::Model::<TestBackend>::new(&device);

    // Input of shape [2, 3, 4, 5]; only its shape is read by the graph.
    let input = Tensor::<TestBackend, 4>::zeros([2, 3, 4, 5], &device);

    // Batch dim is gathered, unsqueezed, then concatenated with the constants.
    let output = model.forward(input);

    // Expect [2, 32, 64]: unsqueezed batch=2 followed by 32 and 64.
    assert_eq!(output, [2i64, 32, 64]);
}
}
Loading
Loading