Skip to content

Commit

Permalink
ConcatToBroadcast transformation (openvinotoolkit#24661)
Browse files Browse the repository at this point in the history
Details:

Due to the issue in the e2e test in
openvinotoolkit#24597 we decided to
exclude Concat to Tile conversion from the transformation. It will be
covered in another ticket.

- Add a ConcatToBroadcast transformation to replace Concat having inputs
from the same output with a Broadcast
- Add a test for the ConcatToBroadcast transformation


Significantly reduce model compile time and performance time.

Tickets:
[CVS-138829](https://jira.devtools.intel.com/browse/CVS-138829),
[CVS-138077](https://jira.devtools.intel.com/browse/CVS-138077)

---------

Signed-off-by: Andrii Staikov <[email protected]>
Co-authored-by: Andrii Staikov <[email protected]>
  • Loading branch information
itikhono and CuriousPanCake authored May 24, 2024
1 parent 2a9eadf commit 17f8e86
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/common/transformations/include/ov_ops/rms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

double get_epsilon() const { return m_epsilon; }
double get_epsilon() const {
return m_epsilon;
}

void set_epsilon(double epsilon) { m_epsilon = epsilon; }
void set_epsilon(double epsilon) {
m_epsilon = epsilon;
}

private:
double m_epsilon{0};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConcatToBroadcast;

} // namespace pass
} // namespace ov

/**
* @ingroup ov_transformation_common_api
* @brief ConcatToBroadcast transformation replaces Concat, having multiple inputs
* from the same output, with a Broadcast node
*/
class ov::pass::ConcatToBroadcast : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConcatToBroadcast", "0");
ConcatToBroadcast();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/concat_to_broadcast.hpp"

#include "itt.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

static bool use_broadcast(const std::shared_ptr<ov::op::v0::Concat>& concat) {
const auto& output = concat->output(0);
const auto& input = concat->input(0);
const auto& input_concat_dim = input.get_partial_shape()[concat->get_concatenation_axis()];

return input_concat_dim.is_static() && input_concat_dim.get_length() == 1 && output.get_partial_shape().is_static();
}

ov::pass::ConcatToBroadcast::ConcatToBroadcast() {
MATCHER_SCOPE(ConcatToBroadcast);

auto concat_label = pattern::wrap_type<op::v0::Concat>([](const Output<Node>& value) {
auto node = value.get_node_shared_ptr();
if (node->output(0).get_partial_shape().rank().is_dynamic()) {
return false;
}

auto first_input_source_output = node->get_input_source_output(0);
if (first_input_source_output.get_partial_shape().rank().is_dynamic()) {
return false;
}

const auto& input_values = node->input_values();

return std::all_of(input_values.cbegin(), input_values.cend(), [&](const ov::Output<Node>& output) {
return first_input_source_output == output;
});
});

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

auto root_node = pattern_map.at(concat_label).get_node_shared_ptr();
auto concat = std::dynamic_pointer_cast<op::v0::Concat>(root_node);
if (!concat) {
return false;
}

if (transformation_callback(concat)) {
return false;
}

const auto& input = concat->input_value(0);

std::shared_ptr<Node> replacement;
if (use_broadcast(concat)) {
auto target_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i32,
Shape{concat->output(0).get_shape().size()},
concat->output(0).get_shape());
replacement = std::make_shared<ov::op::v3::Broadcast>(input, target_shape);
} else {
return false;
}

/* Common case (converting to Tile) causes an issue in e2e test with unknown root cause (ticket: 142246)
else {
std::vector<size_t> repeat_num_vec(concat->output(0).get_partial_shape().rank().get_length(), 1);
repeat_num_vec[concat->get_concatenation_axis()] = concat->get_input_size();
auto repeat_num =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{repeat_num_vec.size()}, repeat_num_vec);
replacement = std::make_shared<ov::op::v0::Tile>(input, repeat_num);
}
*/

replacement->set_friendly_name(concat->get_friendly_name());
ov::replace_node(concat, replacement);

ov::copy_runtime_info(concat, replacement);

return true;
};

auto m = std::make_shared<pattern::Matcher>(concat_label, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
#include "transformations/common_optimizations/broadcast_transition.hpp"
#include "transformations/common_optimizations/clamp_fusion.hpp"
#include "transformations/common_optimizations/concat_to_broadcast.hpp"
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
#include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp"
Expand Down Expand Up @@ -186,6 +187,8 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
REGISTER_PASS(manager, GRUCellFusion)
REGISTER_PASS(manager, SequenceFusion)

REGISTER_PASS(manager, ConcatToBroadcast);

auto transpose_sinking = manager.register_pass<ov::pass::GraphRewrite>();
ADD_MATCHER(transpose_sinking, TransposeSinking)
// SplitSqueezeConcatFusion should work in same GraphRewrite as TransposesSinking,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/concat_to_broadcast.hpp"

#include <gtest/gtest.h>

#include "openvino/core/node_vector.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/pass/manager.hpp"

using namespace testing;

enum class ExpectedType {
Broadcast,
Tile,
Concat,
};

using ConcatToBroadcastParams = std::tuple<ov::PartialShape, size_t, size_t, ExpectedType>;

class ConcatToBroadcastTest : public WithParamInterface<ConcatToBroadcastParams>, public testing::Test {
protected:
void SetUp() override {
std::tie(data_shape, concat_num_inputs, concat_axis, expected_type) = GetParam();
}

ov::PartialShape data_shape;
size_t concat_num_inputs;
ExpectedType expected_type;
size_t concat_axis;
};

INSTANTIATE_TEST_SUITE_P(
type_prop,
ConcatToBroadcastTest,
Values(ConcatToBroadcastParams({2, 1}, 10, 1, ExpectedType::Broadcast),
ConcatToBroadcastParams({1, 2, 1, 2}, 2604, 0, ExpectedType::Broadcast),
ConcatToBroadcastParams(ov::PartialShape::dynamic(), 2604, 0, ExpectedType::Concat),
// Common case (converting to Tile) causes an issue in e2e test with unknown root cause (ticket: 142246)
// The following tests cover Tile cases, but temporary Concat remains in the graph.
ConcatToBroadcastParams({1, 2, 1, 2}, 2604, 1, ExpectedType::Concat),
ConcatToBroadcastParams({-1, 2, 1, 2}, 2604, 0, ExpectedType::Concat),
ConcatToBroadcastParams({-1, -1, -1, -1}, 2604, 0, ExpectedType::Concat)));

TEST_P(ConcatToBroadcastTest, TestTransfromationExecuted) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, data_shape);
std::vector<ov::Output<ov::Node>> concat_inputs(concat_num_inputs, param->output(0));

auto concat = std::make_shared<ov::op::v0::Concat>(concat_inputs, concat_axis);
auto result = std::make_shared<ov::op::v0::Result>(concat->output(0));

auto model = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param});

ov::pass::Manager manager;
manager.register_pass<ov::pass::ConcatToBroadcast>();
manager.run_passes(model);

const auto& ops = model->get_ordered_ops();

size_t tile_count = 0;
size_t broadcast_count = 0;
size_t concat_count = 0;

for (auto& op : ops) {
std::cout << op << std::endl;
if (std::dynamic_pointer_cast<ov::op::v3::Broadcast>(op)) {
++broadcast_count;
} else if (std::dynamic_pointer_cast<ov::op::v0::Tile>(op)) {
++tile_count;
} else if (std::dynamic_pointer_cast<ov::op::v0::Concat>(op)) {
++concat_count;
}
}

if (expected_type == ExpectedType::Broadcast) {
ASSERT_EQ(broadcast_count, 1);
ASSERT_EQ(tile_count, 0);
ASSERT_EQ(concat_count, 0);
} else if (expected_type == ExpectedType::Tile) {
ASSERT_EQ(broadcast_count, 0);
ASSERT_EQ(tile_count, 1);
ASSERT_EQ(concat_count, 0);
} else if (expected_type == ExpectedType::Concat) {
ASSERT_EQ(broadcast_count, 0);
ASSERT_EQ(tile_count, 0);
ASSERT_EQ(concat_count, 1);
}
}

0 comments on commit 17f8e86

Please sign in to comment.