1111#include " core/providers/coreml/shape_utils.h"
1212#include " core/providers/shared/utils/utils.h"
1313#include " core/optimizer/initializer.h"
14+ #include " core/providers/cpu/tensor/unsqueeze.h"
1415
1516namespace onnxruntime {
1617namespace coreml {
@@ -27,7 +28,7 @@ class SqueezeOpBuilder : public BaseOpBuilder {
2728};
2829
2930namespace {
30- void GetAxes (ModelBuilder& model_builder, const Node& node, std::vector< int64_t > & axes) {
31+ void GetAxes (ModelBuilder& model_builder, const Node& node, TensorShapeVector & axes) {
3132 // Squeeze opset 13 use input as axes
3233 if (node.SinceVersion () > 12 ) {
3334 // If axes is not provided, return an empty axes as default to squeeze all
@@ -41,7 +42,8 @@ void GetAxes(ModelBuilder& model_builder, const Node& node, std::vector<int64_t>
4142 }
4243 } else {
4344 NodeAttrHelper helper (node);
44- axes = helper.Get (" axes" , std::vector<int64_t >());
45+ auto axes_attr = helper.Get (" axes" , std::vector<int64_t >());
46+ axes.assign (axes_attr.begin (), axes_attr.end ());
4547 }
4648}
4749} // namespace
@@ -58,7 +60,7 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
5860 std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer (node);
5961 const auto & input_defs (node.InputDefs ());
6062 auto * coreml_squeeze = layer->mutable_squeeze ();
61- std::vector< int64_t > axes;
63+ TensorShapeVector axes;
6264 GetAxes (model_builder, node, axes);
6365 std::vector<int64_t > input_shape;
6466 GetShape (*input_defs[0 ], input_shape, logger);
@@ -72,24 +74,12 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
7274
7375 if (coreml_op_type == " squeeze" ) {
7476 if (!axes.empty ()) {
75- AddOperationInput (*op, " axes" , model_builder.AddConstant (op->type (), " axes" , axes));
77+ // coreml squeeze op does support negative axes
78+ AddOperationInput (*op, " axes" , model_builder.AddConstant (op->type (), " axes" , AsSpan (axes)));
7679 }
7780 } else {
78- for (auto & axis : axes) {
79- axis = HandleNegativeAxis (axis, input_shape.size () + axes.size ());
80- }
81- std::vector<int64_t > new_shape (axes.size () + input_shape.size (), 1 );
82- std::sort (axes.begin (), axes.end ());
83- // For example: Given an input tensor (data) of shape [3, 4, 5],
84- // then Unsqueeze(data, axes=[0, 4]) outputs a tensor (expanded) containing same data as data but with shape [1, 3, 4, 5, 1].
85- for (size_t i = 0 , ori_i = 0 , axes_i = 0 ; i < new_shape.size (); i++) {
86- if ((axes_i >= axes.size () || static_cast <int64_t >(i) != axes[axes_i]) && input_shape.size () >= ori_i) {
87- new_shape[i] = input_shape[ori_i++];
88- } else {
89- axes_i++;
90- }
91- }
92- AddOperationInput (*op, " shape" , model_builder.AddConstant (op->type (), " shape" , new_shape));
81+ TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape (TensorShape (input_shape), axes);
82+ AddOperationInput (*op, " shape" , model_builder.AddConstant (op->type (), " shape" , AsSpan (output_shape)));
9383 }
9484 AddOperationOutput (*op, *node.OutputDefs ()[0 ]);
9585 model_builder.AddOperation (std::move (op));
@@ -118,7 +108,7 @@ bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP
118108 if (node.SinceVersion () > 12 && input_defs.size () > 1 ) {
119109 const auto & axes_name = input_defs[1 ]->Name ();
120110 if (!input_params.graph_viewer .GetConstantInitializer (axes_name)) {
121- LOGS (logger, VERBOSE) << " Input axes of Squeeze must be known" ;
111+ LOGS (logger, VERBOSE) << " Input axes must be known" ;
122112 return false ;
123113 }
124114 }
@@ -127,19 +117,19 @@ bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP
127117 if (!input_params.create_mlprogram ) {
128118 return false ;
129119 }
130- int64_t rank = -1 ;
120+
121+ int64_t num_of_new_dims = 0 ;
131122 if (node.SinceVersion () > 12 ) {
132- const auto & axes_tensor = *input_params.graph_viewer .GetConstantInitializer (input_defs[1 ]->Name ());
133- Initializer unpacked_tensor (axes_tensor);
134- rank = unpacked_tensor.size ();
123+ num_of_new_dims = node.InputDefs ()[1 ]->Shape ()->dim (0 ).dim_value ();
135124 } else {
136125 NodeAttrHelper helper (node);
137126 auto axes = helper.Get (" axes" , std::vector<int64_t >());
138- rank = static_cast <int64_t >(axes.size ());
127+ num_of_new_dims = static_cast <int64_t >(axes.size ());
139128 }
129+
140130 std::vector<int64_t > input_shape;
141- if (!GetShape (*input_defs[0 ], input_shape, logger) || input_shape.size () + rank > 5 ) {
142- LOGS (logger, VERBOSE) << " Unsqueeze with rank > 5 is not supported" ;
131+ if (!GetShape (*input_defs[0 ], input_shape, logger) || input_shape.size () + num_of_new_dims > 5 ) {
132+ LOGS (logger, VERBOSE) << " Unsqueeze with num_of_new_dims > 5 is not supported" ;
143133 return false ;
144134 }
145135 }
0 commit comments