//

#include "core/operator_set.hpp"
- #include "exceptions.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
+ #include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/log.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/negative.hpp"
+ #include "openvino/op/not_equal.hpp"
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/reduce_sum.hpp"
- #include "openvino/op/select.hpp"
#include "openvino/op/softmax.hpp"
- #include "utils/common.hpp"
#include "softmax_cross_entropy_loss.hpp"
+ #include <numeric>  // std::iota, used to build the list of reduction axes

namespace ov {
namespace frontend {
namespace onnx {
namespace {
- // softmax cross entropy implementation (Shared helper fn)
- OutputVector impl_softmax_cross_entropy(const Node& node, int64_t axis_default) {
-     const auto inputs = node.get_ov_inputs();
+ OutputVector impl_softmax_cross_entropy(const Node& node, int64_t axis_default) {
+     const auto inputs = node.get_ov_inputs();

-     const auto scores = inputs[0];
-     const auto labels = inputs[1];
+     const auto scores = inputs[0];
+     const auto labels = inputs[1];

-     const auto axis = node.get_attribute_value<int64_t>("axis", axis_default);
-     const auto reduction = node.get_attribute_value<std::string>("reduction", "mean");
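+     // Optional third input: per-class weights used to scale each sample's loss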
+     bool has_weights = inputs.size() > 2;
+     std::shared_ptr<ov::Node> weights_gather = nullptr;

-     // Computing softmax
-     const auto softmax = std::make_shared<ov::op::v8::Softmax>(scores, axis);
-     const auto log_softmax = std::make_shared<ov::op::v0::Log>(softmax);
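+     // The "ignore_index" attribute marks labels that must not contribute to the
+     // loss; build a 0/1 mask of valid positions in the score element type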
+     bool has_ignore_index = node.has_attribute("ignore_index");
+     int64_t ignore_index_val = 0;
+     std::shared_ptr<ov::Node> mask = nullptr;

-     const auto axis_const = ov::op::v0::Constant::create(element::i64, {}, {axis});
-     const auto gathered = std::make_shared<ov::op::v8::Gather>(log_softmax, labels, axis_const);
+     if (has_ignore_index) {
+         ignore_index_val = node.get_attribute_value<int64_t>("ignore_index");
+         auto ignore_index_node = ov::op::v0::Constant::create(labels.get_element_type(), {}, {ignore_index_val});
+         auto neq = std::make_shared<ov::op::v1::NotEqual>(labels, ignore_index_node);
+         mask = std::make_shared<ov::op::v0::Convert>(neq, scores.get_element_type());
+     }

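+     // Gather a per-sample weight for every label and zero it out for ignored samples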
+     if (has_weights) {
+         const auto weights = inputs[2];
+         const auto axis_for_weights = ov::op::v0::Constant::create(element::i64, {}, {0});
+         weights_gather = std::make_shared<ov::op::v8::Gather>(weights, labels, axis_for_weights);

-     // Computing loss
-     std::shared_ptr<ov::Node> loss = std::make_shared<ov::op::v0::Negative>(gathered);
+         if (has_ignore_index) {
+             weights_gather = std::make_shared<ov::op::v1::Multiply>(weights_gather, mask);
+         }
+     } else if (has_ignore_index) {
+         weights_gather = mask;
+     }

-     // applying reduction as mentioned in https://github.com/onnx/onnx/blob/main/docs/Changelog.md#softmaxcrossentropyloss-12
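+     // Attributes: "axis" selects the class dimension, "reduction" defaults to "mean"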
+     const auto axis = node.get_attribute_value<int64_t>("axis", axis_default);
+     const auto reduction = node.get_attribute_value<std::string>("reduction", "mean");

-     if (reduction != "none") {
-         const auto reduce_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-
-         loss = (reduction == "mean")
-                    ? static_cast<std::shared_ptr<ov::Node>>(
-                          std::make_shared<ov::op::v1::ReduceMean>(loss->output(0), reduce_axis, true))
-                    : static_cast<std::shared_ptr<ov::Node>>(
-                          std::make_shared<ov::op::v1::ReduceSum>(loss->output(0), reduce_axis, true));
-     }
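+     // Log-softmax of the scores along the class axis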
+     const auto softmax = std::make_shared<ov::op::v8::Softmax>(scores, axis);
+     const auto log_softmax = std::make_shared<ov::op::v0::Log>(softmax);

-     return {loss};
- }
- }
- namespace ai_onnx {
- namespace opset_12 {
- OutputVector ov::frontend::onnx::ai_onnx::opset_12::softmax_cross_entropy_loss(const Node& node) {
-     return impl_softmax_cross_entropy(node, 1);
- }
- ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_SINCE(12), ai_onnx::opset_12::softmax_cross_entropy_loss);
- }
- namespace opset_13 {
- OutputVector ov::frontend::onnx::ai_onnx::opset_13::softmax_cross_entropy_loss(const Node& node) {
-     return impl_softmax_cross_entropy(node, 1);
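+     // Pick the log-probability of each sample's target class and negate it,
+     // then apply the per-sample weights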
+     const auto axis_const = ov::op::v0::Constant::create(element::i64, {}, {axis});
+     const auto gathered = std::make_shared<ov::op::v8::Gather>(log_softmax, labels, axis_const);
+
+     std::shared_ptr<ov::Node> loss = std::make_shared<ov::op::v0::Negative>(gathered);
+
+     if (weights_gather) {
+         loss = std::make_shared<ov::op::v1::Multiply>(loss, weights_gather);
    }

-     ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_SINCE(13), ai_onnx::opset_13::softmax_cross_entropy_loss);
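+     // Apply the reduction over all axes, as described in
+     // https://github.com/onnx/onnx/blob/main/docs/Changelog.md#softmaxcrossentropyloss-12;
+     // the weighted "mean" divides the loss sum by the sum of the applied weights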
+     if (reduction != "none") {
+         auto loss_shape = loss->get_output_partial_shape(0);
+         if (loss_shape.rank().is_static()) {
+             size_t loss_rank = loss_shape.rank().get_length();
+             std::vector<int64_t> reduce_axes(loss_rank);
+             std::iota(reduce_axes.begin(), reduce_axes.end(), 0);
+             auto reduce_axis = ov::op::v0::Constant::create(ov::element::i64, {reduce_axes.size()}, reduce_axes);
+
+             if (reduction == "mean") {
+                 if (weights_gather) {
+                     auto loss_sum = std::make_shared<ov::op::v1::ReduceSum>(loss, reduce_axis, false);
+                     auto weight_sum = std::make_shared<ov::op::v1::ReduceSum>(weights_gather, reduce_axis, false);
+                     loss = std::make_shared<ov::op::v1::Divide>(loss_sum, weight_sum);
+                 } else {
+                     loss = std::make_shared<ov::op::v1::ReduceMean>(loss, reduce_axis, false);
+                 }
+             } else if (reduction == "sum") {
+                 loss = std::make_shared<ov::op::v1::ReduceSum>(loss, reduce_axis, false);
+             }
+         } else {
+             OPENVINO_THROW("Dynamic rank is not supported for SoftmaxCrossEntropyLoss reduction.");
+         }
    }
- }
+
+     return {loss};
+ }
+ }  // namespace
+ namespace ai_onnx {
+ namespace opset_12 {
+ OutputVector softmax_cross_entropy_loss(const Node& node) {
+     return impl_softmax_cross_entropy(node, 1);
}
+ ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_IN(12), ai_onnx::opset_12::softmax_cross_entropy_loss);
+ }  // namespace opset_12
+ namespace opset_13 {
+ OutputVector softmax_cross_entropy_loss(const Node& node) {
+     return impl_softmax_cross_entropy(node, 1);
}
- }
+ ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_IN(13), ai_onnx::opset_13::softmax_cross_entropy_loss);
+ }  // namespace opset_13
+ }  // namespace ai_onnx
+ }  // namespace onnx
+ }  // namespace frontend
+ }  // namespace ov