Skip to content

Commit

Permalink
[TF FE] SparseSegmentMean translator (openvinotoolkit#26540)
Browse files Browse the repository at this point in the history
### Details:
 - Added translator for SparseSegmentMean 

### Tickets:
 - 149705

---------

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
popovaan and rkazants authored Sep 13, 2024
1 parent 9de51dd commit 4e17541
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| SparseReduceSumSparse | NO | |
| SparseReorder | NO | |
| SparseReshape | YES | |
| SparseSegmentMean | NO | |
| SparseSegmentMean | YES | |
| SparseSegmentMeanGrad | NO | |
| SparseSegmentMeanGradV2 | NO | |
| SparseSegmentMeanWithNumSegments | NO | |
Expand Down
1 change: 1 addition & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Softmax", CreatorFunction(translate_softmax_op)},
{"SpaceToDepth", CreatorFunction(translate_space_to_depth_op)},
{"SparseReshape", CreatorFunction(translate_sparse_reshape_op)},
{"SparseSegmentMean", CreatorFunction(translate_sparse_segment_mean_op)},
{"SparseTensorDenseAdd", CreatorFunction(translate_sparse_tensor_dense_add_op)},
{"SparseTensorDenseMatMul", CreatorFunction(translate_sparse_tensor_dense_mat_mul_op)},
{"SparseToDense", CreatorFunction(translate_sparse_to_dense_op)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ OP_CONVERTER(translate_rsqrt_op);
OP_CONVERTER(translate_scatter_nd_op);
OP_CONVERTER(translate_segment_sum_op);
OP_CONVERTER(translate_space_to_batch_nd_op);
OP_CONVERTER(translate_sparse_segment_mean_op);
OP_CONVERTER(translate_sparse_tensor_dense_add_op);
OP_CONVERTER(translate_sparse_tensor_dense_mat_mul_op);
OP_CONVERTER(translate_sparse_to_dense_op);
Expand Down
75 changes: 75 additions & 0 deletions src/frontends/tensorflow_common/src/op/sparse_segment_mean.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "common_op_table.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/embedding_segments_sum.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unique.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_sparse_segment_mean_op(const NodeContext& node) {
default_op_checks(node, 3, {"SparseSegmentMean"});
auto data = node.get_input(0);
auto indices = std::make_shared<v0::Convert>(node.get_input(1), element::i64);
auto segment_ids = std::make_shared<v0::Convert>(node.get_input(2), element::i64);
auto data_rank = std::make_shared<v3::ShapeOf>(std::make_shared<v3::ShapeOf>(node.get_input(0)));

// get the last index from segment_ids
auto segments_ids_size = std::make_shared<v3::ShapeOf>(segment_ids, element::i64);
auto const_one = create_same_type_const<int32_t>(indices, vector<int32_t>{1}, Shape{1});
auto const_zero = create_same_type_const<int32_t>(indices, vector<int32_t>{0}, Shape{1});
auto last_idx = std::make_shared<v1::Subtract>(segments_ids_size, const_one);

// segment_ids are always sorted, so the last index from segment_ids can be used to determine the number of output
// segments.
auto last_segment_idx = std::make_shared<v8::Gather>(segment_ids, last_idx, const_zero);
auto n_segments = std::make_shared<v1::Add>(last_segment_idx, const_one);

// get sums of sparse segments
auto embedding_segments_sum =
make_shared<v3::EmbeddingSegmentsSum>(data, indices, segment_ids, std::make_shared<v0::Squeeze>(n_segments));

// get the sizes of each segment
auto unique_segment_ids = make_shared<v10::Unique>(segment_ids, true, element::i64, element::i64);
auto broadcast = make_shared<v3::Broadcast>(const_one, n_segments);
auto divisors = make_shared<v3::ScatterUpdate>(broadcast,
unique_segment_ids->output(0),
unique_segment_ids->output(3),
const_zero);
auto divisors_with_correct_type = make_shared<v1::ConvertLike>(divisors, data);
auto divisors_shape = make_shared<v3::ScatterUpdate>(make_shared<v3::Broadcast>(const_one, data_rank),
const_zero,
n_segments,
const_zero);
auto divisors_with_correct_shape = std::make_shared<v1::Reshape>(divisors_with_correct_type, divisors_shape, false);

// divide the sums by the size of the segments
auto mean = std::make_shared<v1::Divide>(embedding_segments_sum, divisors_with_correct_shape);

set_node_name(node.get_name(), mean);
return {mean};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
79 changes: 79 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_SparseSegmentMean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

rng = np.random.default_rng(475912)


class TestSparseSegmentMean(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'indices:0' in inputs_info
assert 'values:0' in inputs_info
assert 'segment_indices:0' in inputs_info

values_shape = inputs_info['values:0']

inputs_data = {}
inputs_data['values:0'] = rng.uniform(-5.0, 5.0, values_shape).astype(self.data_type)

# generate all possible indices
all_indices = []
for row_ind in range(0, self.shape[0]):
all_indices.append(row_ind)
inputs_data['indices:0'] = rng.choice(all_indices, self.indices_shape[0], replace=False).astype(self.indices_type)

segment_ids = []
for ind in range(0, self.indices_shape[0]):
segment_ids.append(self.segment_indices_type(rng.integers(0, self.segments_num)))
inputs_data['segment_indices:0'] = sorted(segment_ids)

return inputs_data

def create_sparse_segment_mean(self, data_type, indices_type, segment_indices_type,
shape, indices_shape, segments_num):
self.data_type = data_type
self.indices_type = indices_type
self.segment_indices_type = segment_indices_type
self.shape = shape
self.indices_shape = indices_shape
self.segments_num = segments_num
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
values = tf.compat.v1.placeholder(data_type, shape, 'values')
indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices')
segments_ids = tf.compat.v1.placeholder(segment_indices_type, indices_shape, 'segment_indices')
tf.raw_ops.SparseSegmentMean(
data=values,
indices=indices,
segment_ids=segments_ids)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

@pytest.mark.parametrize('data_type', [np.float16, np.float32, np.float64])
@pytest.mark.parametrize('indices_type', [np.int32, np.int64])
@pytest.mark.parametrize('segment_indices_type', [np.int32, np.int64])
@pytest.mark.parametrize('shape, indices_shape, segments_num', [
[[10], [7], 8],
[[5], [5], 3],
[[5], [2], 4],
[[10, 20], [7], 8],
[[10, 2, 4], [10], 4]
])
@pytest.mark.precommit
@pytest.mark.nightly
def test_sparse_segment_mean(self, data_type, indices_type, segment_indices_type,
shape, indices_shape, segments_num,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
if ie_device == 'GPU':
pytest.skip("GPU error: to_shape was called on a dynamic shape, ticket: 152352")
self._test(*self.create_sparse_segment_mean(data_type, indices_type, segment_indices_type,
shape, indices_shape, segments_num),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit 4e17541

Please sign in to comment.