Skip to content

Commit

Permalink
Add is_nonzero operator and test
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Oct 18, 2023
1 parent 4914541 commit 837a357
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/frontends/pytorch/src/op/is_nonzero.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/not_equal.hpp"
#include "openvino/op/shape_of.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_is_nonzero(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto input = context.get_input(0);

// check if length input is 1
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
auto is_length_one = context.mark_node(std::make_shared<v1::Equal>(input_shape, const_1));

// check if element is not equal to 0 or false
auto zero_tensor = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0.0}));
auto false_tensor = context.mark_node(v0::Constant::create(element::boolean, Shape{1}, {false}));

// perform type conversion
auto converted_input = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));

auto is_nonzero_numeric = context.mark_node(std::make_shared<v1::NotEqual>(converted_input, zero_tensor));
auto is_nonzero_boolean = context.mark_node(std::make_shared<v1::NotEqual>(input, false_tensor));

auto final_result = context.mark_node(std::make_shared<v1::LogicalAnd>(
is_length_one,
context.mark_node(std::make_shared<v1::LogicalOr>(is_nonzero_numeric, is_nonzero_boolean))));

return {final_result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ OP_CONVERTER(translate_index_put_);
OP_CONVERTER(translate_index_select);
OP_CONVERTER(translate_instance_norm);
OP_CONVERTER(translate_int);
OP_CONVERTER(translate_is_nonzero);
OP_CONVERTER(translate_layer_norm);
OP_CONVERTER(translate_len);
OP_CONVERTER(translate_linalg_norm);
Expand Down Expand Up @@ -355,6 +356,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::Int", op::translate_int},
{"aten::IntImplicit", op::translate_int},
{"aten::is_grad_enabled", op::return_false_scalar},
{"aten::is_nonzero", op::translate_is_nonzero},
{"aten::item", op::translate_1to1_match_1_inputs<opset10::Squeeze>},
{"aten::layer_norm", op::translate_layer_norm},
{"aten::le", op::translate_1to1_match_2_inputs_align_types<opset10::LessEqual>},
Expand Down
32 changes: 32 additions & 0 deletions tests/layer_tests/pytorch_tests/test_is_nonzero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


@pytest.mark.parametrize('input_tensor', (np.array([0.]), np.array([1.5]), np.array([False]), np.array([3]), np.array([1, 3, 5])))
class TestIsNonZero(PytorchLayerTest):

def _prepare_input(self):
input_tensor = self.input_tensor
return (input_tensor.astype(np.int64),)

def create_model(self):
class aten_is_nonzero(torch.nn.Module):

def forward(self, input_tensor):
return torch.is_nonzero(input_tensor)

ref_net = None

return aten_is_nonzero(), ref_net, "aten::is_nonzero"

@pytest.mark.nightly
@pytest.mark.precommit
def test_is_nonzero(self, ie_device, precision, ir_version, input_tensor):
self.input_tensor = input_tensor
self._test(*self.create_model(), ie_device, precision, ir_version)

0 comments on commit 837a357

Please sign in to comment.