Skip to content

[PT FE] Support aten::take operation #29479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0656890
added aten::take function in pytorch, but tests are not running
vijaykr338 Mar 15, 2025
98c1c61
implemented aten::str, aten::delete but unable to write their tests
vijaykr338 Mar 15, 2025
725c474
Merge branch 'master' into master
vijaykr338 Mar 15, 2025
d70f2f5
Merge branch 'master' into master
vijaykr338 Mar 17, 2025
d6adb61
Merge branch 'master' into master
vijaykr338 Mar 19, 2025
0393cde
fixed op_table.cpp
vijaykr338 Mar 19, 2025
cd831da
Merge branch 'master' into master
vijaykr338 Mar 24, 2025
3eea874
Merge branch 'master' into master
vijaykr338 Mar 28, 2025
3555b2e
fixed code style, added out of bounds check for take function, added …
vijaykr338 Mar 28, 2025
063f8e1
Merge branch 'openvinotoolkit:master' into master
vijaykr338 Mar 29, 2025
54aa9c9
Delete src/frontends/pytorch/src/op/str.cpp
vijaykr338 Apr 4, 2025
6648067
Delete src/frontends/pytorch/src/op/delete.cpp
vijaykr338 Apr 4, 2025
5b2eb21
Delete tests/layer_tests/pytorch_tests/test_delete.py
vijaykr338 Apr 4, 2025
0a72cd5
Update op_table.cpp
vijaykr338 Apr 4, 2025
6dbb55a
Merge branch 'master' into master
mvafin Apr 4, 2025
30b4183
Merge branch 'master' into master
vijaykr338 Apr 6, 2025
be53f0b
added support for quantized relu6
vijaykr338 Apr 6, 2025
5e1e379
Delete tests/layer_tests/pytorch_tests/test_quantized_relu6.py
vijaykr338 Apr 6, 2025
e9caf31
Update op_table.cpp
vijaykr338 Apr 6, 2025
6e071df
Delete src/frontends/pytorch/src/op/quantized_relu6.cpp
vijaykr338 Apr 6, 2025
08c22c5
Merge branch 'master' into master
mvafin Apr 7, 2025
5bd5bb8
Update tests/layer_tests/pytorch_tests/test_take.py
vijaykr338 Apr 7, 2025
31b6ece
Merge branch 'master' into master
mlukasze May 21, 2025
db60469
Merge branch 'master' into master
mlukasze Jun 3, 2025
396f16c
Update src/frontends/pytorch/src/op/take.cpp
vijaykr338 Jun 28, 2025
055e355
Update src/frontends/pytorch/src/op/take.cpp
vijaykr338 Jun 28, 2025
ed3e61e
Update src/frontends/pytorch/src/op/take.cpp
vijaykr338 Jun 28, 2025
2535095
Update src/frontends/pytorch/src/op/take.cpp
vijaykr338 Jun 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/frontends/pytorch/src/op/take.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

// Converts aten::take(Tensor self, Tensor index, *, Tensor(a!) out=None) -> Tensor.
// torch.take treats the input as if it were flattened to 1-D and gathers the
// elements at the given linear indices; negative indices count from the end.
// NOTE: scalar (0-D) inputs are valid in PyTorch —
//   torch.take(torch.tensor(5), torch.tensor(0))   -> tensor(5)
//   torch.take(torch.tensor(5), torch.tensor([0])) -> tensor([5])
// so, unlike an earlier draft of this conversion, no rank check rejects them:
// Reshape to {-1} flattens a 0-D tensor to shape {1} just as well.
OutputVector translate_take_op(const NodeContext& context) {
    num_inputs_check(context, 2, 3);
    auto input = context.get_input(0);
    auto indices = context.get_input(1);
    // Flatten input to 1-D; special_zero=false so -1 means "infer full size".
    auto flat_shape = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-1}));
    input = context.mark_node(std::make_shared<v1::Reshape>(input, flat_shape, false));
    // Gather needs a concrete integer index type; PyTorch indices may arrive
    // as other integral types, so normalize to i64.
    indices = context.mark_node(std::make_shared<v0::Convert>(indices, element::i64));
    // Gather along axis 0 of the flattened tensor; v8::Gather supports
    // negative indices, matching torch.take semantics.
    auto axis_constant = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
    auto gather = context.mark_node(std::make_shared<v8::Gather>(input, indices, axis_constant));
    // Optional third input is the out= tensor; write the result into it.
    if (!context.input_is_none(2)) {
        context.mutate_input(2, gather);
    }

    return {gather};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ OP_CONVERTER(translate_sub_);
OP_CONVERTER(translate_sum);
OP_CONVERTER(translate_t);
OP_CONVERTER(translate_take_along_dim);
OP_CONVERTER(translate_take_op);
OP_CONVERTER(translate_to);
OP_CONVERTER(translate_topk);
OP_CONVERTER(translate_transpose);
Expand Down Expand Up @@ -725,6 +726,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::swapaxes", op::quantizable_op<op::translate_transpose>},
{"aten::t", op::translate_t},
{"aten::take_along_dim", op::translate_take_along_dim},
{"aten::take", op::translate_take_op},
{"aten::tan", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>, 1>},
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tan>>},
{"aten::tanh", op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tanh>, 1>},
Expand Down
35 changes: 35 additions & 0 deletions tests/layer_tests/pytorch_tests/test_take.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
import torch
from pytorch_layer_test_class import PytorchLayerTest

class TestTake(PytorchLayerTest):
    # Layer test for aten::take: gathers elements from a flattened input
    # tensor at the given linear indices (negative indices are allowed).

    def _prepare_input(self, input_shape, indices_shape, max_val):
        # Random float input plus linear indices drawn from the full valid
        # range [-numel, numel - 1] (randint's upper bound is exclusive).
        data = np.random.randn(*input_shape).astype(np.float32)
        linear_indices = np.random.randint(-max_val, max_val, indices_shape).astype(np.int64)
        return (data, linear_indices)

    def create_model(self):
        class aten_take(torch.nn.Module):
            def forward(self, x, indices):
                return torch.take(x, indices)

        # No reference net: the framework compares against PyTorch directly.
        return aten_take(), None, "aten::take"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_torch_export
    @pytest.mark.parametrize("input_shape", [(10,), (3, 4), (2, 3, 4)])
    @pytest.mark.parametrize("indices_shape", [(5,), (2, 2), (3, 2)])
    def test_take(self, input_shape, indices_shape, ie_device, precision, ir_version):
        num_elements = np.prod(input_shape)
        self._test(*self.create_model(), ie_device, precision, ir_version,
                   kwargs_to_prepare_input=dict(input_shape=input_shape,
                                                indices_shape=indices_shape,
                                                max_val=num_elements))