Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Oct 19, 2023
1 parent 837a357 commit 25cea50
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/frontends/pytorch/src/op/is_nonzero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,16 @@ using namespace ov::op;

OutputVector translate_is_nonzero(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto input = context.get_input(0);

// check if length input is 1
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
auto is_length_one = context.mark_node(std::make_shared<v1::Equal>(input_shape, const_1));

// check if element is not equal to 0 or false
auto zero_tensor = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0.0}));
auto one_tensor = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
auto false_tensor = context.mark_node(v0::Constant::create(element::boolean, Shape{1}, {false}));

// check if length input is 1
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
auto is_length_one = context.mark_node(std::make_shared<v1::Equal>(input_shape, one_tensor));

// perform type conversion
auto converted_input = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));

Expand Down

0 comments on commit 25cea50

Please sign in to comment.