Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Support prim::TupleIndex operation #19978

Merged
merged 5 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/frontends/pytorch/src/op/tuple_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_tuple_index(const NodeContext& context) {
    // prim::TupleIndex(Any tup, int i) -> Any
    // Extracts element `i` from a tuple. Two lowering strategies:
    //   1. If the tuple input is produced by prim::TupleConstruct, the element is
    //      taken directly from the construct's inputs (requires a constant index).
    //   2. Otherwise the tuple is assumed to be represented as a tensor, and the
    //      element is selected with Gather along axis 0.
    num_inputs_check(context, 2, 2);
    auto tuple = context.get_input(0).get_node_shared_ptr();
    if (cast_fw_node(tuple, "prim::TupleConstruct")) {
        // This case requires the index to be constant.
        auto index = context.const_input<int64_t>(1);
        // Support Python-style negative indexing (e.g. tup[-1]).
        if (index < 0) {
            index += static_cast<int64_t>(tuple->get_input_size());
        }
        FRONT_END_OP_CONVERSION_CHECK(index >= 0 && static_cast<size_t>(index) < tuple->get_input_size(),
                                      "Index of TupleIndex operation is higher than number of tuple elements.");
        return {tuple->get_input_source_output(static_cast<size_t>(index))};
    } else {
        // Assume this case is when tuple is represented as tensor:
        // select the element dynamically along the first dimension.
        auto index = context.get_input(1);
        auto zero = v0::Constant::create(element::i32, Shape{}, {0});
        return {std::make_shared<v8::Gather>(context.get_input(0), index, zero)};
    }
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ OP_CONVERTER(translate_topk);
OP_CONVERTER(translate_transpose);
OP_CONVERTER(translate_tril);
OP_CONVERTER(translate_triu);
OP_CONVERTER(translate_tuple_index);
OP_CONVERTER(translate_unflatten);
OP_CONVERTER(translate_unfold);
OP_CONVERTER(translate_upsample_bicubic2d);
Expand Down Expand Up @@ -479,6 +480,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"prim::requires_grad", op::return_false_scalar},
{"prim::PythonOp", op::translate_pythonop},
{"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode.
{"prim::TupleIndex", op::translate_tuple_index},
{"quantized::add", op::translate_quantized_add},
{"quantized::add_relu", op::translate_quantized_add_relu},
{"quantized::cat", op::translate_quantized_cat},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ SoftmaxReshapeElimination::SoftmaxReshapeElimination() {

register_matcher(
std::make_shared<ov::pass::pattern::Matcher>(m_reshape1,
"ov::frontend::pytorch::pass::PrimTupleUnpackReplacer"),
"ov::frontend::pytorch::pass::SoftmaxReshapeElimination"),
[=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto reshape0 = pattern_to_output[m_reshape0].get_node_shared_ptr();
Expand Down
38 changes: 29 additions & 9 deletions tests/layer_tests/pytorch_tests/test_tuple_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,21 @@ def forward(self, x):
def prepare_input(self, x):
return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)


ref_net = None

return prim_tuple_construct_tuple_unpack(), ref_net, ["prim::TupleConstruct", "prim::TupleUnpack"]

@pytest.mark.nightly
def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
self._test(*self.create_model(), ie_device,
precision, ir_version, freeze_model=False)


class TestTupleUnpackParameterSingle(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( (tensor_gen(), tensor_gen()), )
return ((tensor_gen(), tensor_gen()), )

def create_model(self):
import torch
Expand All @@ -105,7 +105,6 @@ def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
x1, x2 = x
return x1, x2


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
Expand All @@ -118,6 +117,7 @@ def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
# generate tensor with a different shape for easier mismatch detection in case of mixed input order

def tensor_gen_2():
return np.random.uniform(0, 50, (2, 3)).astype(np.float32)
return (tensor_gen_2(), (tensor_gen(), tensor_gen()), tensor_gen_2())
Expand All @@ -132,7 +132,6 @@ def forward(self, y1, x: Tuple[torch.Tensor, torch.Tensor], y2):
x1, x2 = x
return x1, x2, y1, y2


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
Expand All @@ -144,7 +143,7 @@ class TestTupleUnpackParameterNested(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())), )
return (((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())), )

def create_model(self):
import torch
Expand All @@ -158,7 +157,6 @@ def forward(self, x: Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor
y3, y4 = x2
return y1, y2, y3, y4


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
Expand All @@ -170,7 +168,7 @@ class TestTupleUnpackParameterMultiple(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( (tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()) )
return ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()))

def create_model(self):
import torch
Expand All @@ -183,9 +181,31 @@ def forward(self, x: Tuple[torch.Tensor, torch.Tensor], y: Tuple[torch.Tensor, t
z3, z4 = y
return z1, z2, z3, z4


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)


class TestTupleIndex(PytorchLayerTest):
    # Provides a single float32 tensor; the model itself packs it into a tuple.
    def _prepare_input(self):
        return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)

    def create_model(self):
        import torch
        from typing import Tuple

        class aten_tuple_index(torch.nn.Module):
            """Packs the input into a 2-tuple and indexes both elements inside
            a helper method so the graph contains prim::TupleIndex nodes."""

            def some_func(self, pair: Tuple[torch.Tensor, torch.Tensor]):
                return pair[1] * 2, pair[0] * 3

            def forward(self, x):
                return self.some_func((x, x))

        return aten_tuple_index(), None, "prim::TupleIndex"

    @pytest.mark.nightly
    def test(self, ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision,
                   ir_version, trace_model=False, freeze_model=False)