22// Licensed under the MIT License.
33
44#include " core/providers/coreml/builders/impl/base_op_builder.h"
5+ #include " core/providers/coreml/builders/impl/builder_utils.h"
56#include " core/providers/coreml/builders/model_builder.h"
7+ #include " core/providers/coreml/shape_utils.h"
68#include " core/providers/coreml/builders/op_builder_factory.h"
79#include " core/providers/shared/utils/utils.h" // for NodeAttrHelper
810
@@ -14,28 +16,119 @@ class ShapeOpBuilder : public BaseOpBuilder {
1416
1517 bool IsOpSupportedImpl (const Node& node, const OpBuilderInputParams& input_params,
1618 const logging::Logger& logger) const override ;
19+ bool HasSupportedInputsImpl (const Node& node, const OpBuilderInputParams& input_params,
20+ const logging::Logger& logger) const override ;
21+ bool SupportsMLProgram () const override { return true ; }
1722};
1823
1924Status ShapeOpBuilder::AddToModelBuilderImpl (ModelBuilder& model_builder, const Node& node,
20- const logging::Logger& /* logger*/ ) const {
21- auto layer = model_builder.CreateNNLayer (node);
22- layer->mutable_getshape ();
23- *layer->mutable_input ()->Add () = node.InputDefs ()[0 ]->Name ();
24- *layer->mutable_output ()->Add () = node.OutputDefs ()[0 ]->Name ();
25- model_builder.AddLayer (std::move (layer));
25+ const logging::Logger& logger) const {
26+ const auto & input_defs = node.InputDefs ();
27+
28+ #if defined(COREML_ENABLE_MLPROGRAM)
29+ if (model_builder.CreateMLProgram ()) {
30+ using namespace CoreML ::Specification::MILSpec;
31+ NodeAttrHelper node_attr_helper{node};
32+ int64_t num_dims = input_defs[0 ]->Shape ()->dim_size ();
33+ int64_t start = HandleNegativeAxis (node_attr_helper.Get (" start" , 0 ), num_dims);
34+
35+ int64_t size = -1 ;
36+ if (node_attr_helper.HasAttr (" end" )) {
37+ int64_t end = HandleNegativeAxis (node_attr_helper.Get (" end" , -1 ), num_dims);
38+ size = end - start;
39+ }
40+
41+ int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32;
42+ std::unique_ptr<Operation> op = model_builder.CreateOperation (node, " shape" );
43+ AddOperationInput (*op, " x" , input_defs[0 ]->Name ());
44+ if (size != -1 || start != 0 ) {
45+ std::string_view layer_input_name_x = model_builder.GetUniqueName (node, " slice_by_size" );
46+ std::vector<int64_t > x0_shape{num_dims};
47+ AddIntermediateOperationOutput (*op, layer_input_name_x, output_datatype, x0_shape);
48+ model_builder.AddOperation (std::move (op));
49+
50+ auto slice_op = model_builder.CreateOperation (node, " slice_by_size" );
51+ AddOperationInput (*slice_op, " x" , layer_input_name_x);
52+ std::vector<int64_t > starts = {start};
53+ std::vector<int64_t > sizes = {size};
54+ AddOperationInput (*slice_op, " begin" , model_builder.AddConstant (slice_op->type (), " begin" , starts));
55+ AddOperationInput (*slice_op, " size" , model_builder.AddConstant (slice_op->type (), " size" , sizes));
56+ AddOperationOutput (*slice_op, *node.OutputDefs ()[0 ], output_datatype);
57+ model_builder.AddOperation (std::move (slice_op));
58+ } else {
59+ AddOperationOutput (*op, *node.OutputDefs ()[0 ], output_datatype);
60+ model_builder.AddOperation (std::move (op));
61+ }
62+ } else // NOLINT
63+ #endif
64+ {
65+ auto layer = model_builder.CreateNNLayer (node);
66+ layer->mutable_getshape ();
67+ *layer->mutable_input ()->Add () = input_defs[0 ]->Name ();
68+ *layer->mutable_output ()->Add () = node.OutputDefs ()[0 ]->Name ();
69+ model_builder.AddLayer (std::move (layer));
70+ }
2671 return Status::OK ();
2772}
2873
29- bool ShapeOpBuilder::IsOpSupportedImpl (const Node& node, const OpBuilderInputParams& /* input_params*/ ,
74+ bool ShapeOpBuilder::IsOpSupportedImpl (const Node& node, const OpBuilderInputParams& input_params,
3075 const logging::Logger& logger) const {
76+ const auto * tensor_shape = node.InputDefs ()[0 ]->Shape ();
77+ if (tensor_shape == nullptr ) {
78+ return false ;
79+ }
80+
3181 NodeAttrHelper node_attr_helper{node};
32- if (node_attr_helper.Get (" start" , 0 ) != 0 ) {
33- LOGS (logger, VERBOSE) << " Shape does not support 'start' attribute with value other than 0" ;
82+ if (!input_params.create_mlprogram ) {
83+ if (node_attr_helper.HasAttr (" end" )) {
84+ LOGS (logger, VERBOSE) << " Shape does not support 'end' attribute" ;
85+ return false ;
86+ }
87+
88+ if (node_attr_helper.Get (" start" , 0 ) != 0 ) {
89+ LOGS (logger, VERBOSE) << " Shape does not support 'start' attribute with value other than 0" ;
90+ return false ;
91+ }
92+ } else {
93+ int64_t size = node_attr_helper.HasAttr (" end" )
94+ ? HandleNegativeAxis (node_attr_helper.Get (" end" , 0 ), tensor_shape->dim_size ())
95+ : tensor_shape->dim_size ();
96+ int64_t start = HandleNegativeAxis (node_attr_helper.Get (" start" , 0 ), tensor_shape->dim_size ());
97+ size = size - start;
98+ if (size == 0 ) {
99+ return false ;
100+ }
101+ }
102+
103+ return true ;
104+ }
105+
106+ bool ShapeOpBuilder::HasSupportedInputsImpl (const Node& node,
107+ [[maybe_unused]] const OpBuilderInputParams& input_params,
108+ const logging::Logger& logger) const {
109+ // We only check the type of input 0
110+ const auto & input = *node.InputDefs ()[0 ];
111+
112+ int32_t input_type;
113+ if (!GetType (input, input_type, logger)) {
34114 return false ;
35115 }
36116
37- if (node_attr_helper.HasAttr (" end" )) {
38- LOGS (logger, VERBOSE) << " Shape does not support 'end' attribute" ;
117+ if (input_params.create_mlprogram ) {
118+ if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
119+ input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
120+ input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) {
121+ return true ;
122+ } else {
123+ LOGS (logger, VERBOSE) << " [" << node.OpType ()
124+ << " ] Input type: [" << input_type
125+ << " ] is not supported." ;
126+ return false ;
127+ }
128+ } else if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
129+ LOGS (logger, VERBOSE) << " [" << node.OpType ()
130+ << " ] Input type: [" << input_type
131+ << " ] is not supported." ;
39132 return false ;
40133 }
41134
0 commit comments