forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support new operations in TS: Selu, Swish, HSwish, Tile, CumSum, Hard…
…Sigmoid (openvinotoolkit#19990) * add new operations as unary * get unary as input(0) instead of iterating pattern map * add CumSum + unit tests * add Tile + unit tests * add tile * fix ts_tile * code review fix: use ADD_MATCHER * fix bug CI tests
- Loading branch information
Showing
9 changed files
with
605 additions
and
49 deletions.
There are no files selected for viewing
41 changes: 41 additions & 0 deletions
41
src/common/transformations/include/transformations/transpose_sinking/ts_cumsum.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "openvino/pass/pass.hpp" | ||
#include "transformations/transpose_sinking/ts_base.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
namespace transpose_sinking { | ||
|
||
class TRANSFORMATIONS_API TSCumSumForward; | ||
class TRANSFORMATIONS_API TSCumSumBackward; | ||
|
||
} // namespace transpose_sinking | ||
} // namespace pass | ||
} // namespace ov | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief TSCumSumForward transformation sinks Transpose through CumSum in the forward direction. | ||
*/ | ||
class ov::pass::transpose_sinking::TSCumSumForward : public ov::pass::transpose_sinking::TSForwardBase { | ||
public: | ||
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0"); | ||
TSCumSumForward(); | ||
}; | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief TSCumSumBackward transformation sinks Transpose through CumSum in the backward direction. | ||
*/ | ||
class ov::pass::transpose_sinking::TSCumSumBackward : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0"); | ||
TSCumSumBackward(); | ||
}; |
41 changes: 41 additions & 0 deletions
41
src/common/transformations/include/transformations/transpose_sinking/ts_tile.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "openvino/pass/pass.hpp" | ||
#include "transformations/transpose_sinking/ts_base.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
namespace transpose_sinking { | ||
|
||
class TRANSFORMATIONS_API TSTileForward; | ||
class TRANSFORMATIONS_API TSTileBackward; | ||
|
||
} // namespace transpose_sinking | ||
} // namespace pass | ||
} // namespace ov | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief TSTileForward transformation sinks Transpose through Tile in the forward direction. | ||
*/ | ||
class ov::pass::transpose_sinking::TSTileForward : public ov::pass::transpose_sinking::TSForwardBase { | ||
public: | ||
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0"); | ||
TSTileForward(); | ||
}; | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief TSTileBackward transformation sinks Transpose through Tile in the backward direction. | ||
*/ | ||
class ov::pass::transpose_sinking::TSTileBackward : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0"); | ||
TSTileBackward(); | ||
}; |
92 changes: 92 additions & 0 deletions
92
src/common/transformations/src/transformations/transpose_sinking/ts_cumsum.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/transpose_sinking/ts_cumsum.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/cum_sum.hpp" | ||
#include "openvino/op/fake_quantize.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/op/util/op_types.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "transformations/rt_info/transpose_sinking_attr.hpp" | ||
#include "transformations/transpose_sinking/ts_utils.hpp" | ||
|
||
using namespace ov; | ||
using namespace ov::pass::pattern; | ||
using namespace ov::pass::transpose_sinking; | ||
using namespace ov::pass::transpose_sinking::utils; | ||
|
||
#undef CUMSUM_AXIS_INPUT_IDX | ||
#define CUMSUM_AXIS_INPUT_IDX 1 | ||
|
||
TSCumSumForward::TSCumSumForward() { | ||
MATCHER_SCOPE(TSCumSumForward); | ||
|
||
create_pattern<ov::op::v0::CumSum>(true, {0}); | ||
|
||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node, | ||
const TransposeInputsInfo& transpose_info) -> bool { | ||
if (transformation_callback(main_node)) { | ||
return false; | ||
} | ||
|
||
bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0}); | ||
if (!res) | ||
return res; | ||
|
||
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); | ||
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0); | ||
const auto& new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), transpose_axis_order, axis); | ||
main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes); | ||
|
||
default_outputs_update(main_node, transpose_info); | ||
return true; | ||
}; | ||
transpose_sinking(matcher_name, sinking_transformation); | ||
} | ||
|
||
TSCumSumBackward::TSCumSumBackward() { | ||
MATCHER_SCOPE(TSCumSumBackward); | ||
auto main_node_label = wrap_type<ov::op::v0::CumSum>([](const Output<Node>& output) -> bool { | ||
return has_static_rank()(output) && CheckTransposeConsumers(output); | ||
}); | ||
|
||
auto transpose_const_label = wrap_type<ov::op::v0::Constant>(); | ||
|
||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label}, | ||
[](const Output<Node>& output) -> bool { | ||
return has_static_rank()(output); | ||
}); | ||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { | ||
const auto& pattern_to_output = m.get_pattern_value_map(); | ||
auto transpose_const = | ||
as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); | ||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); | ||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); | ||
|
||
if (transformation_callback(main_node)) { | ||
return false; | ||
} | ||
|
||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, | ||
transpose_const, | ||
/* input_indexes= */ {0})) { | ||
register_new_node(new_node); | ||
} | ||
|
||
RemoveTransposeConsumers(main_node); | ||
const auto transpose_axis_order = transpose_const->get_axis_vector_val(); | ||
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order); | ||
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0); | ||
auto new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), reversed_transpose_order, axis); | ||
main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes); | ||
|
||
main_node->validate_and_infer_types(); | ||
return true; | ||
}; | ||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name); | ||
register_matcher(m, matcher_pass_callback); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
93 changes: 93 additions & 0 deletions
93
src/common/transformations/src/transformations/transpose_sinking/ts_tile.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/transpose_sinking/ts_tile.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/fake_quantize.hpp" | ||
#include "openvino/op/tile.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/op/util/op_types.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "transformations/rt_info/transpose_sinking_attr.hpp" | ||
#include "transformations/transpose_sinking/ts_utils.hpp" | ||
|
||
using namespace ov; | ||
using namespace ov::pass::pattern; | ||
using namespace ov::pass::transpose_sinking; | ||
using namespace ov::pass::transpose_sinking::utils; | ||
|
||
#undef TILE_REPEATS_INPUT_IDX | ||
#define TILE_REPEATS_INPUT_IDX 1 | ||
|
||
TSTileForward::TSTileForward() { | ||
MATCHER_SCOPE(TSTileForward); | ||
|
||
create_pattern<ov::op::v0::Tile>(true, {0}); | ||
|
||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node, | ||
const TransposeInputsInfo& transpose_info) -> bool { | ||
if (transformation_callback(main_node)) { | ||
return false; | ||
} | ||
|
||
bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0}); | ||
if (!res) | ||
return res; | ||
|
||
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); | ||
auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0); | ||
const auto& new_repeats = | ||
ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats); | ||
main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats); | ||
|
||
default_outputs_update(main_node, transpose_info); | ||
return true; | ||
}; | ||
transpose_sinking(matcher_name, sinking_transformation); | ||
} | ||
|
||
TSTileBackward::TSTileBackward() { | ||
MATCHER_SCOPE(TSTileBackward); | ||
auto main_node_label = wrap_type<ov::op::v0::Tile>([](const Output<Node>& output) -> bool { | ||
return has_static_rank()(output) && CheckTransposeConsumers(output); | ||
}); | ||
|
||
auto transpose_const_label = wrap_type<ov::op::v0::Constant>(); | ||
|
||
auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label}, | ||
[](const Output<Node>& output) -> bool { | ||
return has_static_rank()(output); | ||
}); | ||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { | ||
const auto& pattern_to_output = m.get_pattern_value_map(); | ||
auto transpose_const = | ||
as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); | ||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); | ||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); | ||
|
||
if (transformation_callback(main_node)) { | ||
return false; | ||
} | ||
|
||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, | ||
transpose_const, | ||
/* input_indexes= */ {0})) { | ||
register_new_node(new_node); | ||
} | ||
|
||
RemoveTransposeConsumers(main_node); | ||
const auto transpose_axis_order = transpose_const->get_axis_vector_val(); | ||
auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0); | ||
auto new_repeats = | ||
ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats); | ||
main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats); | ||
|
||
main_node->validate_and_infer_types(); | ||
return true; | ||
}; | ||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name); | ||
register_matcher(m, matcher_pass_callback); | ||
} |
Oops, something went wrong.