diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index bdc13df7b0e08a..2ea0e696c1bba8 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -166,8 +166,8 @@ const std::map get_supported_ops() { {"Swish", CreatorFunction(translate_unary_op)}, // note: BinaryOp translator declaration for each op must to be added in binary_op.cpp file - {"Add", CreatorFunction(translate_binary_op)}, - {"AddV2", CreatorFunction(translate_binary_op)}, + {"Add", CreatorFunction(translate_addv2_op)}, + {"AddV2", CreatorFunction(translate_addv2_op)}, {"Atan2", CreatorFunction(translate_atan2_op)}, {"BitwiseAnd", CreatorFunction(translate_binary_op)}, {"BitwiseOr", CreatorFunction(translate_binary_op)}, diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index f0b564f0f07a86..41110be14288da 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -31,7 +31,7 @@ OP_T_CONVERTER(translate_unary_op); OP_CONVERTER(translate_selu_op); OP_T_CONVERTER(translate_binary_op); OP_T_CONVERTER(translate_direct_reduce_op); - +OP_CONVERTER(translate_addv2_op); OP_CONVERTER(translate_add_n_op); OP_CONVERTER(translate_adjust_contrast_op); OP_CONVERTER(translate_arg_max_op); diff --git a/src/frontends/tensorflow_common/src/op/binary_op.cpp b/src/frontends/tensorflow_common/src/op/binary_op.cpp index e992c1bfc0b760..832aff9409c288 100644 --- a/src/frontends/tensorflow_common/src/op/binary_op.cpp +++ b/src/frontends/tensorflow_common/src/op/binary_op.cpp @@ -142,6 +142,33 @@ OutputVector translate_mul_op(const NodeContext& node) { set_node_name(node.get_name(), result); return {result}; } + +OutputVector translate_addv2_op(const NodeContext& node) { + default_op_checks(node, 2, {"Add", "AddV2"}, true); + auto lhs = node.get_input(0); + auto rhs = node.get_input(1); + + auto complex_type_mark_lhs = as_type_ptr(lhs.get_node_shared_ptr()); + auto complex_type_mark_rhs = as_type_ptr(rhs.get_node_shared_ptr()); + auto complex_type_inputs = (complex_type_mark_lhs || complex_type_mark_rhs) ? true : false; + + if (complex_type_inputs) { + lhs = complex_type_mark_lhs->input_value(0); + rhs = complex_type_mark_rhs->input_value(0); + } + + auto result = make_shared(lhs, rhs); + if (complex_type_inputs) { + auto complex_result = make_shared(result, complex_type_mark_lhs->get_complex_part_type()); + set_node_name(node.get_name(), result); + + return {complex_result}; + } + + set_node_name(node.get_name(), result); + return {result}; +} + template OutputVector translate_binary_op(const NodeContext& node); template OutputVector translate_binary_op(const NodeContext& node); template OutputVector translate_binary_op(const NodeContext& node); diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Add.py b/tests/layer_tests/tensorflow_tests/test_tf_Add.py index 7fb4977b802698..fa8245bf288796 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Add.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Add.py @@ -222,3 +222,60 @@ def test_add_placeholder_const_broadcast_5D(self, params, ie_device, precision, use_legacy_frontend=use_legacy_frontend), ie_device, precision, ir_version=ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + + +class TestComplexAdd(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'param_real_1:0' in inputs_info + assert 'param_imag_1:0' in inputs_info + assert 'param_real_2:0' in inputs_info + assert 'param_imag_2:0' in inputs_info + param_real_shape_1 = inputs_info['param_real_1:0'] + param_imag_shape_1 = inputs_info['param_imag_1:0'] + param_real_shape_2 = inputs_info['param_real_2:0'] + param_imag_shape_2 = inputs_info['param_imag_2:0'] + inputs_data = {} + inputs_data['param_real_1:0'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2 + inputs_data['param_imag_1:0'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2 + inputs_data['param_real_2:0'] = 4 * rng.random(param_real_shape_2).astype(np.float32) - 2 + inputs_data['param_imag_2:0'] = 4 * rng.random(param_imag_shape_2).astype(np.float32) - 2 + return inputs_data + + def create_complex_addv2_net(self, input_shape): + import tensorflow as tf + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + param_real1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real_1') + param_imag1 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag_1') + param_real2 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real_2') + param_imag2 = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag_2') + complex1 = tf.raw_ops.Complex(real=param_real1, imag=param_imag1) + complex2 = tf.raw_ops.Complex(real=param_real2, imag=param_imag2) + add = tf.raw_ops.AddV2(x=complex1, y=complex2, name="complex_add") + real = tf.raw_ops.Real(input=add) + img = tf.raw_ops.Imag(input=add) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + + test_data_basic = [ + dict(input_shape=[]), + dict(input_shape=[2]), + dict(input_shape=[1, 3]), + dict(input_shape=[2, 3, 4]), + dict(input_shape=[3, 4, 5, 6]), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_complex_add(self, params, ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + self._test( + *self.create_complex_addv2_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend)