Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Added support for aten::hstack and aten::vstack #28933

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/frontends/pytorch/src/op/stack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_stack_common(const NodeContext& context,
const std::deque<ov::Output<ov::Node>>& list_elems,
int64_t axis) {
auto first_node = list_elems.front().get_node_shared_ptr();
if (list_elems.size() == 1 &&
!ov::as_type_ptr<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr())) {
auto tensor = list_elems[0];
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(tensor, element::i32));
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto axis_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {axis}));
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto int_max =
context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>().max()}));
auto shape_sliced = context.mark_node(std::make_shared<v8::Slice>(shape, one, int_max, one));
auto new_shape =
context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(shape_sliced, axis_const, neg_1, zero));
return {context.mark_node(std::make_shared<v1::Reshape>(tensor, new_shape, false))};
}

const auto first_in_type = list_elems.front().get_element_type();
const bool is_mixed_type =
list_elems.size() > 1 && (std::any_of(std::next(list_elems.begin()),
list_elems.end(),
[&first_in_type](const ov::Output<ov::Node>& input) {
return input.get_element_type() != first_in_type ||
input.get_element_type() == ov::element::dynamic;
}));
auto inputs_vec = OutputVector(list_elems.begin(), list_elems.end());
if (is_mixed_type) {
auto node_of_type = inputs_vec[0];
for (size_t i = 1; i < inputs_vec.size(); ++i) {
auto cpt = context.mark_node(std::make_shared<v14::ConvertPromoteTypes>(node_of_type, list_elems[i], true));
node_of_type = cpt->output(0);
inputs_vec[i] = cpt->output(1);
}

inputs_vec[0] = node_of_type;
const auto unified_type = node_of_type.get_element_type();
for (size_t i = 1; i < inputs_vec.size(); ++i) {
if (inputs_vec[i].get_element_type() != unified_type ||
inputs_vec[i].get_element_type() == ov::element::dynamic) {
inputs_vec[i] = context.mark_node(std::make_shared<v1::ConvertLike>(list_elems[i], node_of_type));
}
}
auto concat = std::make_shared<v0::Concat>(inputs_vec, axis);
return {context.mark_node(concat)};
}

return {context.mark_node(std::make_shared<v0::Concat>(inputs_vec, axis))};
}

OutputVector translate_hstack(const NodeContext& context) {
num_inputs_check(context, 1, 2);
const auto&& list_elems = get_list_as_outputs(context.get_input(0));
int64_t axis = 1;
auto out = translate_stack_common(context, list_elems, axis);
if (!context.input_is_none(1)) {
context.mutate_input(1, out[0]);
}
return out;
};

OutputVector translate_vstack(const NodeContext& context) {
num_inputs_check(context, 1, 2);
const auto&& list_elems = get_list_as_outputs(context.get_input(0));
int64_t axis = 0;
auto out = translate_stack_common(context, list_elems, axis);
if (!context.input_is_none(1)) {
context.mutate_input(1, out[0]);
}
return out;
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ OP_CONVERTER(translate_group_norm);
OP_CONVERTER(translate_gru);
OP_CONVERTER(translate_hann_window);
OP_CONVERTER(translate_hardtanh);
OP_CONVERTER(translate_hstack);
OP_CONVERTER(translate_if);
OP_CONVERTER(translate_im2col);
OP_CONVERTER(translate_imag);
Expand Down Expand Up @@ -327,6 +328,7 @@ OP_CONVERTER(translate_quantize_per_channel_fx);
OP_CONVERTER(translate_quantize_per_tensor_fx);
OP_CONVERTER(translate_var_fx);
OP_CONVERTER(translate_var_mean_fx);
OP_CONVERTER(translate_vstack);
OP_CONVERTER(translate_unbind_int_fx);
OP_CONVERTER(translate_unique2);
OP_CONVERTER(translate_zeros_fx);
Expand Down Expand Up @@ -506,6 +508,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::hardsigmoid", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSigmoid>>},
{"aten::hardswish", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
{"aten::hardtanh", op::quantizable_op<op::translate_hardtanh>},
{"aten::hstack", op::translate_hstack},
{"aten::im2col", op::translate_im2col},
{"aten::imag", op::translate_imag},
// aten::index - Supported in limited set of patterns
Expand Down Expand Up @@ -710,6 +713,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::view_as", op::translate_reshape_as},
{"aten::view_as_complex", op::translate_view_as_complex},
{"aten::view_as_real", op::translate_view_as_real},
{"aten::vstack", op::translate_vstack},
{"aten::wait", op::skip_node},
{"aten::where", op::translate_where},
{"aten::zero", op::translate_zeros_like},
Expand Down
96 changes: 96 additions & 0 deletions tests/layer_tests/pytorch_tests/test_hstack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import numpy as np
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest

class aten_hstack(torch.nn.Module):
def __init__(self):
super().__init__()

code-dev05 marked this conversation as resolved.
Show resolved Hide resolved
def forward(self, x):
return torch.hstack(self.prepare_input(x))

def prepare_input(self, x):
return (x, x)

class aten_hstack_out(aten_hstack):
def forward(self, x, out):
return torch.hstack(self.prepare_input(x), out=out), out

class TestHstack(PytorchLayerTest):
def _prepare_input(self, out=False, num_repeats=2):
data = np.random.randn(2, 1, 3)
if not out:
return (data, )
concat = [data for _ in range(num_repeats)]
out = np.zeros_like(np.concatenate(concat, axis=1))
return (data, out)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("out", [False, True])
def test_hstack(self, out, ie_device, precision, ir_version):
model = aten_hstack() if not out else aten_hstack_out()
self._test(model, None, "aten_hstack", ie_device,
precision, ir_version, kwargs={"out": out, "num_repeats": 2})


class TestHstackAlignTypes(PytorchLayerTest):
def _prepare_input(self, in_types):
in_vals = []
for i in range(len(in_types)):
dtype = in_types[i]
in_vals.append(np.random.randn(2, 1, 3).astype(dtype))
return in_vals

def create_model(self, in_count):
class aten_align_types_hstack_two_args(torch.nn.Module):
def __init__(self):
super().__init()

def forward(self, x, y):
ins = [x, y]
return torch.hstack(ins)

class aten_align_types_hstack_three_args(torch.nn.Module):
def __init__(self):
super().__init()

def forward(self, x, y, z):
ins = [x, y, z]
return torch.hstack(ins)

if in_count == 2:
return aten_align_types_hstack_two_args()

if in_count == 3:
return aten_align_types_hstack_three_args()

@pytest.mark.parametrize(("in_types"), [
(np.float32, np.int32),
(np.int32, np.float32),
(np.float16, np.float32),
(np.int16, np.float16),
(np.int32, np.int64),
# # Three inputs
(np.float32, np.int32, np.int32),
(np.float32, np.int32, np.float32),
(np.int32, np.float32, np.int32),
(np.float32, np.int32, np.int16),
(np.int32, np.float32, np.int16),
(np.int16, np.int32, np.int16),
(np.float16, np.float32, np.float16),
(np.float32, np.float16, np.float32),
(np.float16, np.int32, np.int16),
(np.int16, np.float16, np.int16)
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_align_types_hstack(self, ie_device, precision, ir_version, in_types):
self._test(self.create_model(len(in_types)), None, "aten::hstack",
ie_device, precision, ir_version, kwargs={"in_types": in_types})
96 changes: 96 additions & 0 deletions tests/layer_tests/pytorch_tests/test_vstack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import numpy as np
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest

class aten_vstack(torch.nn.Module):
def __init__(self):
super().__init__()

code-dev05 marked this conversation as resolved.
Show resolved Hide resolved
def forward(self, x):
return torch.vstack(self.prepare_input(x))

def prepare_input(self, x):
return (x, x)

class aten_vstack_out(aten_vstack):
def forward(self, x, out):
return torch.vstack(self.prepare_input(x), out=out), out

class TestVstack(PytorchLayerTest):
def _prepare_input(self, out=False, num_repeats=2):
data = np.random.randn(2, 1, 3)
if not out:
return (data, )
concat = [data for _ in range(num_repeats)]
out = np.zeros_like(np.concatenate(concat, axis=0))
return (data, out)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("out", [False, True])
def test_vstack(self, out, ie_device, precision, ir_version):
model = aten_vstack() if not out else aten_vstack_out()
self._test(model, None, "aten_vstack", ie_device,
precision, ir_version, kwargs={"out": out, "num_repeats": 2})


class TestVstackAlignTypes(PytorchLayerTest):
def _prepare_input(self, in_types):
in_vals = []
for i in range(len(in_types)):
dtype = in_types[i]
in_vals.append(np.random.randn(2, 1, 3).astype(dtype))
return in_vals

def create_model(self, in_count):
class aten_align_types_vstack_two_args(torch.nn.Module):
def __init__(self):
super().__init()

def forward(self, x, y):
ins = [x, y]
return torch.vstack(ins)

class aten_align_types_vstack_three_args(torch.nn.Module):
def __init__(self):
super().__init()

def forward(self, x, y, z):
ins = [x, y, z]
return torch.vstack(ins)

if in_count == 2:
return aten_align_types_vstack_two_args()

if in_count == 3:
return aten_align_types_vstack_three_args()

@pytest.mark.parametrize(("in_types"), [
(np.float32, np.int32),
(np.int32, np.float32),
(np.float16, np.float32),
(np.int16, np.float16),
(np.int32, np.int64),
# # Three inputs
(np.float32, np.int32, np.int32),
(np.float32, np.int32, np.float32),
(np.int32, np.float32, np.int32),
(np.float32, np.int32, np.int16),
(np.int32, np.float32, np.int16),
(np.int16, np.int32, np.int16),
(np.float16, np.float32, np.float16),
(np.float32, np.float16, np.float32),
(np.float16, np.int32, np.int16),
(np.int16, np.float16, np.int16)
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_align_types_vstack(self, ie_device, precision, ir_version, in_types):
self._test(self.create_model(len(in_types)), None, "aten::vstack",
ie_device, precision, ir_version, kwargs={"in_types": in_types})
Loading