Skip to content

Commit

Permalink
[TF FE] Add support for TensorScatterAdd in TF FE (#28419)
Browse files Browse the repository at this point in the history
**Overview**:

This pull request fixes #25050
All testcases passed

Continuation of PR #26481


**Dependencies**:

- No dependencies on other pull requests.

**CC**:

@rkazants, @mlukasze

---------

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
sumhaj and rkazants authored Jan 14, 2025
1 parent e390175 commit 842fedc
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| TensorListSetItem | YES | |
| TensorListSplit | NO | |
| TensorListStack | YES | |
| TensorScatterAdd | NO | |
| TensorScatterAdd | YES | |
| TensorScatterMax | NO | |
| TensorScatterMin | NO | |
| TensorScatterSub | NO | |
Expand Down
1 change: 1 addition & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"TensorListReserve", CreatorFunction(translate_tensor_list_reserve_op)},
{"TensorListResize", CreatorFunction(translate_tensor_list_resize_op)},
{"TensorListConcatV2", CreatorFunction(translate_tensor_list_concat_v2_op)},
{"TensorScatterAdd", CreatorFunction(translate_tensor_scatter_add_op)},
{"TensorScatterUpdate", CreatorFunction(translate_tensor_scatter_update_op)},
{"Tile", CreatorFunction(translate_tile_op)},
{"ToBool", CreatorFunction(translate_tobool_op)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ OP_CONVERTER(translate_tensor_list_set_item_op);
OP_CONVERTER(translate_tensor_list_stack_op);
OP_CONVERTER(translate_tensor_list_resize_op);
OP_CONVERTER(translate_tensor_list_concat_v2_op);
OP_CONVERTER(translate_tensor_scatter_add_op);
OP_CONVERTER(translate_tensor_scatter_update_op);
OP_CONVERTER(translate_tile_op);
OP_CONVERTER(translate_tobool_op);
Expand Down
29 changes: 29 additions & 0 deletions src/frontends/tensorflow_common/src/op/tensor_scatter_add.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "common_op_table.hpp"
#include "openvino/op/scatter_nd_update.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_tensor_scatter_add_op(const NodeContext& node) {
default_op_checks(node, 3, {"TensorScatterAdd"});
auto data = node.get_input(0);
auto indices = node.get_input(1);
auto updates = node.get_input(2);
auto reduction = v15::ScatterNDUpdate::Reduction::SUM;
auto scatter_add_op = make_shared<v15::ScatterNDUpdate>(data, indices, updates, reduction);
set_node_name(node.get_name(), scatter_add_op);

return {scatter_add_op};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
89 changes: 89 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_TensorScatterAdd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

rng = np.random.default_rng(872173)


class TestTensorScatterAdd(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'tensor:0' in inputs_info
assert 'indices:0' in inputs_info
assert 'updates:0' in inputs_info

tensor_shape = inputs_info['tensor:0']
updates_shape = inputs_info['updates:0']
indices_shape = inputs_info['indices:0']

inputs_data = {}
if np.issubdtype(self.data_type, np.floating):
inputs_data['tensor:0'] = rng.uniform(-5.0, 5.0, tensor_shape).astype(self.data_type)
inputs_data['updates:0'] = rng.uniform(-5.0, 5.0, updates_shape).astype(self.data_type)
elif np.issubdtype(self.data_type, np.signedinteger):
inputs_data['tensor:0'] = rng.integers(-8, 8, tensor_shape).astype(self.data_type)
inputs_data['updates:0'] = rng.integers(-8, 8, updates_shape).astype(self.data_type)
else:
inputs_data['tensor:0'] = rng.integers(0, 8, tensor_shape).astype(self.data_type)
inputs_data['updates:0'] = rng.integers(0, 8, updates_shape).astype(self.data_type)

indices_rows, indices_col = indices_shape

indices_of_tensor_shape = []
for i in range(0, indices_col):
indices_of_tensor_shape.append(np.arange(tensor_shape[i]))

mesh = np.meshgrid(*indices_of_tensor_shape)

all_indicies = np.stack(mesh, axis=indices_col)
all_indicies = all_indicies.reshape(-1, all_indicies.shape[-1])

inputs_data['indices:0'] = rng.choice(all_indicies, indices_rows, replace=False).astype(self.indices_type)

return inputs_data

def create_tensor_scatter_add_net(self, data_type, indices_type, tensor_shape, updates_shape, indices_shape):
self.data_type = data_type
self.indices_type = indices_type
self.tensor_shape = tensor_shape
self.updates_shape = updates_shape
self.indices_shape = indices_shape
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices')
tensor = tf.compat.v1.placeholder(data_type, tensor_shape, 'tensor')
updates = tf.compat.v1.placeholder(data_type, updates_shape, 'updates')
tf.raw_ops.TensorScatterAdd(
tensor=tensor,
indices=indices,
updates=updates)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

ref_net = None

return tf_net, ref_net

@pytest.mark.parametrize('data_type', [np.float32, np.float64, np.int32])
@pytest.mark.parametrize('indices_type', [np.int32, np.int64])
@pytest.mark.parametrize('tensor_shape, updates_shape, indices_shape', [
[[10, 5], [2], [2, 2]],
[[4, 4, 4], [2, 4, 4], [2, 1]],
[[2, 4, 8], [3], [3, 3]],
[[4, 3, 5], [1, 5], [1, 2]],
])
@pytest.mark.precommit
@pytest.mark.nightly
def test_tensor_scatter_add(self, data_type, indices_type,
tensor_shape, updates_shape, indices_shape,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
if ie_device == 'GPU':
pytest.skip("160549: ScatterNDUpdate(opset15) is not supported on GPU")
self._test(*self.create_tensor_scatter_add_net(data_type, indices_type,
tensor_shape, updates_shape, indices_shape),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit 842fedc

Please sign in to comment.