Skip to content

Commit

Permalink
requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Oct 20, 2023
1 parent 6607b79 commit ec18187
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
14 changes: 3 additions & 11 deletions src/frontends/pytorch/src/op/is_nonzero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/not_equal.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
Expand All @@ -20,17 +19,10 @@ OutputVector translate_is_nonzero(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto input = context.get_input(0);

Output<Node> zero_tensor = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0.0}));
auto false_tensor = context.mark_node(v0::Constant::create(element::boolean, Shape{1}, {false}));
Output<Node> zero_tensor = context.mark_node(v0::Constant::create(element::boolean, Shape{1}, {false}));

std::shared_ptr<ov::Node> result;

if (input.get_element_type() == element::boolean) {
result = context.mark_node(std::make_shared<v1::NotEqual>(input, false_tensor));
} else {
align_eltwise_input_types(context, input, zero_tensor);
result = context.mark_node(std::make_shared<v1::NotEqual>(input, zero_tensor));
}
align_eltwise_input_types(context, input, zero_tensor);
auto result = context.mark_node(std::make_shared<v1::NotEqual>(input, zero_tensor));

return {result};
};
Expand Down
2 changes: 1 addition & 1 deletion tests/layer_tests/pytorch_tests/test_is_nonzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pytorch_layer_test_class import PytorchLayerTest


@pytest.mark.parametrize('input_tensor', (np.array([0.]), np.array([1.5]), np.array([False]), np.array([3]), np.array([1, 3, 5])))
@pytest.mark.parametrize('input_tensor', (np.array([0.]), np.array([1.5]), np.array([False]), np.array([3])))
class TestIsNonZero(PytorchLayerTest):

def _prepare_input(self):
Expand Down

0 comments on commit ec18187

Please sign in to comment.