diff --git a/src/frontends/pytorch/src/op/is_nonzero.cpp b/src/frontends/pytorch/src/op/is_nonzero.cpp index 7635860e36d52f..8f5a33f43a1a8b 100644 --- a/src/frontends/pytorch/src/op/is_nonzero.cpp +++ b/src/frontends/pytorch/src/op/is_nonzero.cpp @@ -22,18 +22,16 @@ using namespace ov::op; OutputVector translate_is_nonzero(const NodeContext& context) { num_inputs_check(context, 1, 1); - auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); - auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); auto input = context.get_input(0); - // check if length input is 1 - auto input_shape = context.mark_node(std::make_shared(input, element::i32)); - auto is_length_one = context.mark_node(std::make_shared(input_shape, const_1)); - - // check if element is not equal to 0 or false auto zero_tensor = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0.0})); + auto one_tensor = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1})); auto false_tensor = context.mark_node(v0::Constant::create(element::boolean, Shape{1}, {false})); + // check if length input is 1 + auto input_shape = context.mark_node(std::make_shared(input)); + auto is_length_one = context.mark_node(std::make_shared(input_shape, one_tensor)); + // perform type conversion auto converted_input = context.mark_node(std::make_shared(input, element::f32));