forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ConcatToBroadcast transformation (openvinotoolkit#24661)
Details: Due to the issue in the e2e test in openvinotoolkit#24597 we decided to exclude Concat to Tile conversion from the transformation. It will be covered in another ticket. - Add a ConcatToBroadcast transformation to replace Concat having inputs from the same output with a Broadcast - Add a test for the ConcatToBroadcast transformation Significantly reduce model compile time and performance time. Tickets: [CVS-138829](https://jira.devtools.intel.com/browse/CVS-138829), [CVS-138077](https://jira.devtools.intel.com/browse/CVS-138077) --------- Signed-off-by: Andrii Staikov <[email protected]> Co-authored-by: Andrii Staikov <[email protected]>
- Loading branch information
1 parent
2a9eadf
commit 17f8e86
Showing
5 changed files
with
219 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 27 additions & 0 deletions
27
...mmon/transformations/include/transformations/common_optimizations/concat_to_broadcast.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API ConcatToBroadcast; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
/** | ||
* @ingroup ov_transformation_common_api | ||
* @brief ConcatToBroadcast transformation replaces Concat, having multiple inputs | ||
* from the same output, with a Broadcast node | ||
*/ | ||
class ov::pass::ConcatToBroadcast : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConcatToBroadcast", "0"); | ||
ConcatToBroadcast(); | ||
}; |
90 changes: 90 additions & 0 deletions
90
src/common/transformations/src/transformations/common_optimizations/concat_to_broadcast.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/concat_to_broadcast.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/op/broadcast.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/tile.hpp" | ||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
static bool use_broadcast(const std::shared_ptr<ov::op::v0::Concat>& concat) { | ||
const auto& output = concat->output(0); | ||
const auto& input = concat->input(0); | ||
const auto& input_concat_dim = input.get_partial_shape()[concat->get_concatenation_axis()]; | ||
|
||
return input_concat_dim.is_static() && input_concat_dim.get_length() == 1 && output.get_partial_shape().is_static(); | ||
} | ||
|
||
ov::pass::ConcatToBroadcast::ConcatToBroadcast() { | ||
MATCHER_SCOPE(ConcatToBroadcast); | ||
|
||
auto concat_label = pattern::wrap_type<op::v0::Concat>([](const Output<Node>& value) { | ||
auto node = value.get_node_shared_ptr(); | ||
if (node->output(0).get_partial_shape().rank().is_dynamic()) { | ||
return false; | ||
} | ||
|
||
auto first_input_source_output = node->get_input_source_output(0); | ||
if (first_input_source_output.get_partial_shape().rank().is_dynamic()) { | ||
return false; | ||
} | ||
|
||
const auto& input_values = node->input_values(); | ||
|
||
return std::all_of(input_values.cbegin(), input_values.cend(), [&](const ov::Output<Node>& output) { | ||
return first_input_source_output == output; | ||
}); | ||
}); | ||
|
||
matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
|
||
auto root_node = pattern_map.at(concat_label).get_node_shared_ptr(); | ||
auto concat = std::dynamic_pointer_cast<op::v0::Concat>(root_node); | ||
if (!concat) { | ||
return false; | ||
} | ||
|
||
if (transformation_callback(concat)) { | ||
return false; | ||
} | ||
|
||
const auto& input = concat->input_value(0); | ||
|
||
std::shared_ptr<Node> replacement; | ||
if (use_broadcast(concat)) { | ||
auto target_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i32, | ||
Shape{concat->output(0).get_shape().size()}, | ||
concat->output(0).get_shape()); | ||
replacement = std::make_shared<ov::op::v3::Broadcast>(input, target_shape); | ||
} else { | ||
return false; | ||
} | ||
|
||
/* Common case (converting to Tile) causes an issue in e2e test with unknown root cause (ticket: 142246) | ||
else { | ||
std::vector<size_t> repeat_num_vec(concat->output(0).get_partial_shape().rank().get_length(), 1); | ||
repeat_num_vec[concat->get_concatenation_axis()] = concat->get_input_size(); | ||
auto repeat_num = | ||
std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{repeat_num_vec.size()}, repeat_num_vec); | ||
replacement = std::make_shared<ov::op::v0::Tile>(input, repeat_num); | ||
} | ||
*/ | ||
|
||
replacement->set_friendly_name(concat->get_friendly_name()); | ||
ov::replace_node(concat, replacement); | ||
|
||
ov::copy_runtime_info(concat, replacement); | ||
|
||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<pattern::Matcher>(concat_label, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
93 changes: 93 additions & 0 deletions
93
src/common/transformations/tests/common_optimizations/concat_to_broadcast_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/concat_to_broadcast.hpp" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include "openvino/core/node_vector.hpp" | ||
#include "openvino/core/shape.hpp" | ||
#include "openvino/op/broadcast.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/tile.hpp" | ||
#include "openvino/pass/manager.hpp" | ||
|
||
using namespace testing; | ||
|
||
enum class ExpectedType { | ||
Broadcast, | ||
Tile, | ||
Concat, | ||
}; | ||
|
||
using ConcatToBroadcastParams = std::tuple<ov::PartialShape, size_t, size_t, ExpectedType>; | ||
|
||
class ConcatToBroadcastTest : public WithParamInterface<ConcatToBroadcastParams>, public testing::Test { | ||
protected: | ||
void SetUp() override { | ||
std::tie(data_shape, concat_num_inputs, concat_axis, expected_type) = GetParam(); | ||
} | ||
|
||
ov::PartialShape data_shape; | ||
size_t concat_num_inputs; | ||
ExpectedType expected_type; | ||
size_t concat_axis; | ||
}; | ||
|
||
INSTANTIATE_TEST_SUITE_P( | ||
type_prop, | ||
ConcatToBroadcastTest, | ||
Values(ConcatToBroadcastParams({2, 1}, 10, 1, ExpectedType::Broadcast), | ||
ConcatToBroadcastParams({1, 2, 1, 2}, 2604, 0, ExpectedType::Broadcast), | ||
ConcatToBroadcastParams(ov::PartialShape::dynamic(), 2604, 0, ExpectedType::Concat), | ||
// Common case (converting to Tile) causes an issue in e2e test with unknown root cause (ticket: 142246) | ||
// The following tests cover Tile cases, but temporary Concat remains in the graph. | ||
ConcatToBroadcastParams({1, 2, 1, 2}, 2604, 1, ExpectedType::Concat), | ||
ConcatToBroadcastParams({-1, 2, 1, 2}, 2604, 0, ExpectedType::Concat), | ||
ConcatToBroadcastParams({-1, -1, -1, -1}, 2604, 0, ExpectedType::Concat))); | ||
|
||
TEST_P(ConcatToBroadcastTest, TestTransfromationExecuted) { | ||
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, data_shape); | ||
std::vector<ov::Output<ov::Node>> concat_inputs(concat_num_inputs, param->output(0)); | ||
|
||
auto concat = std::make_shared<ov::op::v0::Concat>(concat_inputs, concat_axis); | ||
auto result = std::make_shared<ov::op::v0::Result>(concat->output(0)); | ||
|
||
auto model = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param}); | ||
|
||
ov::pass::Manager manager; | ||
manager.register_pass<ov::pass::ConcatToBroadcast>(); | ||
manager.run_passes(model); | ||
|
||
const auto& ops = model->get_ordered_ops(); | ||
|
||
size_t tile_count = 0; | ||
size_t broadcast_count = 0; | ||
size_t concat_count = 0; | ||
|
||
for (auto& op : ops) { | ||
std::cout << op << std::endl; | ||
if (std::dynamic_pointer_cast<ov::op::v3::Broadcast>(op)) { | ||
++broadcast_count; | ||
} else if (std::dynamic_pointer_cast<ov::op::v0::Tile>(op)) { | ||
++tile_count; | ||
} else if (std::dynamic_pointer_cast<ov::op::v0::Concat>(op)) { | ||
++concat_count; | ||
} | ||
} | ||
|
||
if (expected_type == ExpectedType::Broadcast) { | ||
ASSERT_EQ(broadcast_count, 1); | ||
ASSERT_EQ(tile_count, 0); | ||
ASSERT_EQ(concat_count, 0); | ||
} else if (expected_type == ExpectedType::Tile) { | ||
ASSERT_EQ(broadcast_count, 0); | ||
ASSERT_EQ(tile_count, 1); | ||
ASSERT_EQ(concat_count, 0); | ||
} else if (expected_type == ExpectedType::Concat) { | ||
ASSERT_EQ(broadcast_count, 0); | ||
ASSERT_EQ(tile_count, 0); | ||
ASSERT_EQ(concat_count, 1); | ||
} | ||
} |