Skip to content

Commit 04d6427

Browse files
committed
Implementation of SoftmaxCrossEntropyLoss function for opset_12 & opset_13
1 parent dfa4812 commit 04d6427

File tree

1 file changed

+79
-44
lines changed

1 file changed

+79
-44
lines changed

src/frontends/onnx/frontend/src/op/softmax_crossentropy_loss.cpp

Lines changed: 79 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,75 +3,110 @@
33
//
44

55
#include "softmax_cross_entropy_loss.hpp"

#include <numeric>

#include "core/operator_set.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/log.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/negative.hpp"
#include "openvino/op/not_equal.hpp"
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/softmax.hpp"
1917

2018
namespace ov {
2119
namespace frontend {
2220
namespace onnx {
2321
namespace {
24-
// softmax cross entropy implementation (Shared helper fn)
25-
OutputVector impl_softmax_cross_entropy(const Node& node, int64_t axis_default) {
26-
const auto inputs = node.get_ov_inputs();
22+
OutputVector impl_softmax_cross_entropy(const Node& node, int64_t axis_default) {
23+
const auto inputs = node.get_ov_inputs();
2724

28-
const auto scores = inputs[0];
29-
const auto labels = inputs[1];
25+
const auto scores = inputs[0];
26+
const auto labels = inputs[1];
3027

31-
const auto axis = node.get_attribute_value<int64_t>("axis", axis_default);
32-
const auto reduction = node.get_attribute_value<std::string>("reduction", "mean");
28+
bool has_weights = inputs.size() > 2;
29+
std::shared_ptr<ov::Node> weights_gather = nullptr;
3330

34-
// Computing softmax
35-
const auto softmax = std::make_shared<ov::op::v8::Softmax>(scores, axis);
36-
const auto log_softmax = std::make_shared<ov::op::v0::Log>(softmax);
31+
bool has_ignore_index = node.has_attribute("ignore_index");
32+
int64_t ignore_index_val = 0;
33+
std::shared_ptr<ov::Node> mask = nullptr;
3734

38-
const auto axis_const = ov::op::v0::Constant::create(element::i64, {}, {axis});
39-
const auto gathered = std::make_shared<ov::op::v8::Gather>(log_softmax, labels, axis_const);
35+
if (has_ignore_index) {
36+
ignore_index_val = node.get_attribute_value<int64_t>("ignore_index");
37+
auto ignore_index_node = ov::op::v0::Constant::create(labels.get_element_type(), {}, {ignore_index_val});
38+
auto neq = std::make_shared<ov::op::v1::NotEqual>(labels, ignore_index_node);
39+
mask = std::make_shared<ov::op::v0::Convert>(neq, scores.get_element_type());
40+
}
4041

42+
if (has_weights) {
43+
const auto weights = inputs[2];
44+
const auto axis_for_weights = ov::op::v0::Constant::create(element::i64, {}, {0});
45+
weights_gather = std::make_shared<ov::op::v8::Gather>(weights, labels, axis_for_weights);
4146

42-
// Computing loss
43-
std::shared_ptr<ov::Node> loss = std::make_shared<ov::op::v0::Negative>(gathered);
47+
if (has_ignore_index) {
48+
weights_gather = std::make_shared<ov::op::v1::Multiply>(weights_gather, mask);
49+
}
50+
} else if (has_ignore_index) {
51+
weights_gather = mask;
52+
}
4453

45-
// applying reduction as mentioned in https://github.com/onnx/onnx/blob/main/docs/Changelog.md#softmaxcrossentropyloss-12
54+
const auto axis = node.get_attribute_value<int64_t>("axis", axis_default);
55+
const auto reduction = node.get_attribute_value<std::string>("reduction", "mean");
4656

47-
if (reduction != "none") {
48-
const auto reduce_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
49-
50-
loss = (reduction == "mean")
51-
? static_cast<std::shared_ptr<ov::Node>>(
52-
std::make_shared<ov::op::v1::ReduceMean>(loss->output(0), reduce_axis, true))
53-
: static_cast<std::shared_ptr<ov::Node>>(
54-
std::make_shared<ov::op::v1::ReduceSum>(loss->output(0), reduce_axis, true));
55-
}
57+
const auto softmax = std::make_shared<ov::op::v8::Softmax>(scores, axis);
58+
const auto log_softmax = std::make_shared<ov::op::v0::Log>(softmax);
5659

57-
return {loss};
58-
}
59-
}
60-
namespace ai_onnx {
61-
namespace opset_12 {
62-
OutputVector ov::frontend::onnx::ai_onnx::opset_12::softmax_cross_entropy_loss(const Node& node) {
63-
return impl_softmax_cross_entropy(node, 1);
64-
}
65-
ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_SINCE(12), ai_onnx::opset_12::softmax_cross_entropy_loss);
66-
}
67-
namespace opset_13 {
68-
OutputVector ov::frontend::onnx::ai_onnx::opset_13::softmax_cross_entropy_loss(const Node& node) {
69-
return impl_softmax_cross_entropy(node, 1);
60+
const auto axis_const = ov::op::v0::Constant::create(element::i64, {}, {axis});
61+
const auto gathered = std::make_shared<ov::op::v8::Gather>(log_softmax, labels, axis_const);
62+
63+
std::shared_ptr<ov::Node> loss = std::make_shared<ov::op::v0::Negative>(gathered);
64+
65+
if (weights_gather) {
66+
loss = std::make_shared<ov::op::v1::Multiply>(loss, weights_gather);
7067
}
7168

72-
ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_SINCE(13), ai_onnx::opset_13::softmax_cross_entropy_loss);
69+
if (reduction != "none") {
70+
auto loss_shape = loss->get_output_partial_shape(0);
71+
if (loss_shape.rank().is_static()) {
72+
size_t loss_rank = loss_shape.rank().get_length();
73+
std::vector<int64_t> reduce_axes(loss_rank);
74+
std::iota(reduce_axes.begin(), reduce_axes.end(), 0);
75+
auto reduce_axis = ov::op::v0::Constant::create(ov::element::i64, {reduce_axes.size()}, reduce_axes);
76+
77+
if (reduction == "mean") {
78+
if (weights_gather) {
79+
auto loss_sum = std::make_shared<ov::op::v1::ReduceSum>(loss, reduce_axis, false);
80+
auto weight_sum = std::make_shared<ov::op::v1::ReduceSum>(weights_gather, reduce_axis, false);
81+
loss = std::make_shared<ov::op::v1::Divide>(loss_sum, weight_sum);
82+
} else {
83+
loss = std::make_shared<ov::op::v1::ReduceMean>(loss, reduce_axis, false);
84+
}
85+
} else if (reduction == "sum") {
86+
loss = std::make_shared<ov::op::v1::ReduceSum>(loss, reduce_axis, false);
87+
}
88+
} else {
89+
OPENVINO_THROW("Dynamic rank is not supported for SoftmaxCrossEntropyLoss reduction.");
90+
}
7391
}
74-
}
92+
93+
return {loss};
94+
}
95+
} // namespace
96+
namespace ai_onnx {
97+
namespace opset_12 {
98+
OutputVector softmax_cross_entropy_loss(const Node& node) {
99+
return impl_softmax_cross_entropy(node, 1);
75100
}
101+
ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_IN(12), ai_onnx::opset_12::softmax_cross_entropy_loss);
102+
} // namespace opset_12
103+
namespace opset_13 {
104+
OutputVector softmax_cross_entropy_loss(const Node& node) {
105+
return impl_softmax_cross_entropy(node, 1);
76106
}
77-
}
107+
ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_IN(13), ai_onnx::opset_13::softmax_cross_entropy_loss);
108+
} // namespace opset_13
109+
} // namespace ai_onnx
110+
} // namespace onnx
111+
} // namespace frontend
112+
} // namespace ov

0 commit comments

Comments
 (0)