From a789e593d9d8f19673cfc6fd0213a5136b63b907 Mon Sep 17 00:00:00 2001 From: siddhant-0707 Date: Thu, 19 Oct 2023 19:50:51 +0530 Subject: [PATCH] Update is_nonzero.cpp --- src/frontends/pytorch/src/op/is_nonzero.cpp | 26 +++++++-------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/src/frontends/pytorch/src/op/is_nonzero.cpp b/src/frontends/pytorch/src/op/is_nonzero.cpp index 8f5a33f43a1a8b..6135af00e4ed72 100644 --- a/src/frontends/pytorch/src/op/is_nonzero.cpp +++ b/src/frontends/pytorch/src/op/is_nonzero.cpp @@ -5,11 +5,7 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/convert.hpp" -#include "openvino/op/equal.hpp" -#include "openvino/op/logical_and.hpp" -#include "openvino/op/logical_or.hpp" #include "openvino/op/not_equal.hpp" -#include "openvino/op/shape_of.hpp" #include "pt_framework_node.hpp" #include "utils.hpp" @@ -25,24 +21,18 @@ OutputVector translate_is_nonzero(const NodeContext& context) { auto input = context.get_input(0); auto zero_tensor = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0.0})); - auto one_tensor = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1})); auto false_tensor = context.mark_node(v0::Constant::create(element::boolean, Shape{1}, {false})); - // check if length input is 1 - auto input_shape = context.mark_node(std::make_shared(input)); - auto is_length_one = context.mark_node(std::make_shared(input_shape, one_tensor)); + std::shared_ptr result; - // perform type conversion - auto converted_input = context.mark_node(std::make_shared(input, element::f32)); + if (input.get_element_type() == element::boolean) { + result = context.mark_node(std::make_shared(input, false_tensor)); + } else { + auto converted_input = context.mark_node(std::make_shared(input, element::f32)); + result = context.mark_node(std::make_shared(converted_input, zero_tensor)); + } - auto is_nonzero_numeric = context.mark_node(std::make_shared(converted_input, zero_tensor)); - auto is_nonzero_boolean = context.mark_node(std::make_shared(input, false_tensor)); - - auto final_result = context.mark_node(std::make_shared( - is_length_one, - context.mark_node(std::make_shared(is_nonzero_numeric, is_nonzero_boolean)))); - - return {final_result}; + return {result}; }; } // namespace op