Skip to content

Commit

Permalink
[PT FE]: support aten::pixel_unshuffle (openvinotoolkit#20325)
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova authored Oct 10, 2023
1 parent 4d9f2f3 commit 0dcde7f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
39 changes: 39 additions & 0 deletions src/frontends/pytorch/src/op/pixel_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,45 @@ OutputVector translate_pixel_shuffle(const NodeContext& context) {
return {context.mark_node(std::make_shared<v1::Reshape>(transpose, shape_after, false))};
};

OutputVector translate_pixel_unshuffle(const NodeContext& context) {
// aten::pixel_unshuffle(Tensor self, int upscale_factor) -> Tensor
num_inputs_check(context, 2, 2);
auto x = context.get_input(0);
auto upscale_factor = context.get_input(1);
auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto neg_3 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto zero_s = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto one_s = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
Output<Node> shape;
Output<Node> rank;
std::tie(shape, rank) = get_shape_rank(context, x, true);
// 1. Reshape input to [-1, C, H / r, r, W / r, r], where r is upscale factor
auto indices = context.mark_node(v0::Constant::create(element::i32, Shape{3}, {-3, -2, -1}));
auto dims = context.mark_node(std::make_shared<v8::Gather>(shape, indices, zero_s));
auto dims_splitted = context.mark_node(std::make_shared<v1::Split>(dims, zero_s, 3));
auto c = dims_splitted->output(0);
auto h = dims_splitted->output(1);
auto w = dims_splitted->output(2);
auto dims_before = context.mark_node(std::make_shared<v8::Slice>(shape, zero, neg_3, one));
auto r = context.mark_node(std::make_shared<v0::Unsqueeze>(upscale_factor, zero));
auto new_h = context.mark_node(std::make_shared<v1::Divide>(h, upscale_factor, true));
auto new_w = context.mark_node(std::make_shared<v1::Divide>(w, upscale_factor, true));
auto intermediate_shape =
context.mark_node(std::make_shared<v0::Concat>(OutputVector{neg_1, c, new_h, r, new_w, r}, 0));
auto x_reshaped = context.mark_node(std::make_shared<v1::Reshape>(x, intermediate_shape, false));
// 2. Transpose to [-1, C, r, r, H / r, W / r]
auto transpose_order = context.mark_node(v0::Constant::create(element::i32, Shape{6}, {0, 1, 3, 5, 2, 4}));
auto x_transposed = context.mark_node(std::make_shared<v1::Transpose>(x_reshaped, transpose_order));
// 3. Reshape to [*, C*r*r, H / r, W / r]
auto r_sqr = context.mark_node(std::make_shared<v1::Multiply>(r, r));
auto new_c = context.mark_node(std::make_shared<v1::Multiply>(c, r_sqr));
auto final_shape =
context.mark_node(std::make_shared<v0::Concat>(OutputVector{dims_before, new_c, new_h, new_w}, 0));
return {context.mark_node(std::make_shared<v1::Reshape>(x_transposed, final_shape, false))};
};

OutputVector translate_channel_shuffle(const NodeContext& context) {
// aten::channel_shuffle(Tensor self, int groups) -> Tensor
num_inputs_check(context, 2, 2);
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ OP_CONVERTER(translate_outer);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pairwise_distance);
OP_CONVERTER(translate_pixel_shuffle);
OP_CONVERTER(translate_pixel_unshuffle);
OP_CONVERTER(translate_pow);
OP_CONVERTER(translate_pythonop);
OP_CONVERTER(translate_quantize_per_channel);
Expand Down Expand Up @@ -409,6 +410,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::pairwise_distance", op::translate_pairwise_distance},
{"aten::permute", op::translate_1to1_match_2_inputs<opset10::Transpose>},
{"aten::pixel_shuffle", op::translate_pixel_shuffle},
{"aten::pixel_unshuffle", op::translate_pixel_unshuffle},
{"aten::prelu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
{"aten::pow", op::translate_pow},
{"aten::quantize_per_channel", op::translate_quantize_per_channel},
Expand Down
30 changes: 29 additions & 1 deletion tests/layer_tests/pytorch_tests/test_pixel_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,34 @@ def test_pixel_shuffle(self, upscale_factor, shape, ie_device, precision, ir_ver
ie_device, precision, ir_version)


class TestPixelUnshuffle(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(*self.shape).astype(np.float32),)

def create_model(self, upscale_factor):
import torch
import torch.nn.functional as F

class aten_pixel_unshuffle(torch.nn.Module):
def __init__(self, upscale_factor):
super(aten_pixel_unshuffle, self).__init__()
self.upscale_factor = upscale_factor

def forward(self, x):
return F.pixel_unshuffle(x, self.upscale_factor)

return aten_pixel_unshuffle(upscale_factor), None, "aten::pixel_unshuffle"

@pytest.mark.parametrize(("upscale_factor,shape"), [(3, [1, 1, 12, 12]),
(2, [1, 2, 3, 2, 8, 8]),])
@pytest.mark.nightly
@pytest.mark.precommit
def test_pixel_unshuffle(self, upscale_factor, shape, ie_device, precision, ir_version):
self.shape = shape
self._test(*self.create_model(upscale_factor),
ie_device, precision, ir_version)


class TestChannelShuffle(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(*self.shape).astype(np.float32),)
Expand Down Expand Up @@ -65,4 +93,4 @@ def forward(self, x):
def test_channel_shuffle(self, groups, shape, ie_device, precision, ir_version):
self.shape = shape
self._test(*self.create_model(groups),
ie_device, precision, ir_version)
ie_device, precision, ir_version)

0 comments on commit 0dcde7f

Please sign in to comment.