Skip to content

Commit

Permalink
Update is_nonzero.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Oct 19, 2023
1 parent 25cea50 commit a789e59
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions src/frontends/pytorch/src/op/is_nonzero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/not_equal.hpp"
#include "openvino/op/shape_of.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

Expand All @@ -25,24 +21,18 @@ OutputVector translate_is_nonzero(const NodeContext& context) {
auto input = context.get_input(0);

auto zero_tensor = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0.0}));
auto one_tensor = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
auto false_tensor = context.mark_node(v0::Constant::create(element::boolean, Shape{1}, {false}));

// check if length input is 1
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
auto is_length_one = context.mark_node(std::make_shared<v1::Equal>(input_shape, one_tensor));
std::shared_ptr<ov::Node> result;

// perform type conversion
auto converted_input = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
if (input.get_element_type() == element::boolean) {
result = context.mark_node(std::make_shared<v1::NotEqual>(input, false_tensor));
} else {
auto converted_input = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
result = context.mark_node(std::make_shared<v1::NotEqual>(converted_input, zero_tensor));
}

auto is_nonzero_numeric = context.mark_node(std::make_shared<v1::NotEqual>(converted_input, zero_tensor));
auto is_nonzero_boolean = context.mark_node(std::make_shared<v1::NotEqual>(input, false_tensor));

auto final_result = context.mark_node(std::make_shared<v1::LogicalAnd>(
is_length_one,
context.mark_node(std::make_shared<v1::LogicalOr>(is_nonzero_numeric, is_nonzero_boolean))));

return {final_result};
return {result};
};

} // namespace op
Expand Down

0 comments on commit a789e59

Please sign in to comment.