Skip to content

Commit

Permalink
Support new operations in TS: Selu, Swish, HSwish, Tile, CumSum, Hard…
Browse files Browse the repository at this point in the history
…Sigmoid (openvinotoolkit#19990)

* add new operations as unary

* get unary as input(0) instead of iterating pattern map

* add CumSum + unit tests

* add Tile + unit tests

* add tile

* fix ts_tile

* code review fix: use ADD_MATCHER

* fix bug CI tests
  • Loading branch information
evkotov authored Oct 10, 2023
1 parent a5b6606 commit 4d9f2f3
Show file tree
Hide file tree
Showing 9 changed files with 605 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/transpose_sinking/ts_base.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
namespace transpose_sinking {

class TRANSFORMATIONS_API TSCumSumForward;
class TRANSFORMATIONS_API TSCumSumBackward;

} // namespace transpose_sinking
} // namespace pass
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief TSCumSumForward transformation sinks Transpose through CumSum in the forward direction.
*/
class ov::pass::transpose_sinking::TSCumSumForward : public ov::pass::transpose_sinking::TSForwardBase {
public:
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0");
TSCumSumForward();
};

/**
* @ingroup ie_transformation_common_api
* @brief TSCumSumBackward transformation sinks Transpose through CumSum in the backward direction.
*/
class ov::pass::transpose_sinking::TSCumSumBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0");
TSCumSumBackward();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/transpose_sinking/ts_base.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
namespace transpose_sinking {

class TRANSFORMATIONS_API TSTileForward;
class TRANSFORMATIONS_API TSTileBackward;

} // namespace transpose_sinking
} // namespace pass
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief TSTileForward transformation sinks Transpose through Tile in the forward direction.
*/
class ov::pass::transpose_sinking::TSTileForward : public ov::pass::transpose_sinking::TSForwardBase {
public:
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0");
TSTileForward();
};

/**
* @ingroup ie_transformation_common_api
* @brief TSTileBackward transformation sinks Transpose through Tile in the backward direction.
*/
class ov::pass::transpose_sinking::TSTileBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0");
TSTileBackward();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/transpose_sinking/ts_cumsum.hpp"

#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/cum_sum.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"

using namespace ov;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;

#undef CUMSUM_AXIS_INPUT_IDX
#define CUMSUM_AXIS_INPUT_IDX 1

TSCumSumForward::TSCumSumForward() {
MATCHER_SCOPE(TSCumSumForward);

create_pattern<ov::op::v0::CumSum>(true, {0});

auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
if (transformation_callback(main_node)) {
return false;
}

bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0});
if (!res)
return res;

const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
const auto& new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), transpose_axis_order, axis);
main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes);

default_outputs_update(main_node, transpose_info);
return true;
};
transpose_sinking(matcher_name, sinking_transformation);
}

TSCumSumBackward::TSCumSumBackward() {
MATCHER_SCOPE(TSCumSumBackward);
auto main_node_label = wrap_type<ov::op::v0::CumSum>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && CheckTransposeConsumers(output);
});

auto transpose_const_label = wrap_type<ov::op::v0::Constant>();

auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label},
[](const Output<Node>& output) -> bool {
return has_static_rank()(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const =
as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();

if (transformation_callback(main_node)) {
return false;
}

for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {
register_new_node(new_node);
}

RemoveTransposeConsumers(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), reversed_transpose_order, axis);
main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes);

main_node->validate_and_infer_types();
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_cumsum.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_gather.hpp"
Expand All @@ -23,6 +24,7 @@
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_tile.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/utils/utils.hpp"
Expand All @@ -31,35 +33,40 @@ using namespace ov::pass::transpose_sinking;

TSGeneralForward::TSGeneralForward() {
MATCHER_SCOPE(TSGeneralForward);
add_matcher<TSUnaryForward>();
add_matcher<TSBinaryForward>();
add_matcher<TSConcatForward>();
add_matcher<TSSplitForward>();
add_matcher<TSDataMovementForward>();
add_matcher<TSReductionForward>();
add_matcher<TSSqueezeForward>();
add_matcher<TSUnsqueezeForward>();
add_matcher<TSInterpolateForward>();
add_matcher<TSSliceForward>();
add_matcher<TSGatherForward>();
add_matcher<TSShapeOfForward>();
add_matcher<TSFuse>();
ADD_MATCHER(this, TSUnaryForward);
ADD_MATCHER(this, TSBinaryForward);
ADD_MATCHER(this, TSConcatForward);
ADD_MATCHER(this, TSSplitForward);
ADD_MATCHER(this, TSDataMovementForward);
ADD_MATCHER(this, TSReductionForward);
ADD_MATCHER(this, TSSqueezeForward);
ADD_MATCHER(this, TSUnsqueezeForward);
ADD_MATCHER(this, TSInterpolateForward);
ADD_MATCHER(this, TSSliceForward);
ADD_MATCHER(this, TSGatherForward);
ADD_MATCHER(this, TSShapeOfForward);
ADD_MATCHER(this, TSCumSumForward);
ADD_MATCHER(this, TSTileForward);
ADD_MATCHER(this, TSFuse);
}

TSGeneralBackward::TSGeneralBackward() {
MATCHER_SCOPE(TSGeneralBackward);
add_matcher<TSUnaryBackward>();
add_matcher<TSBinaryBackward>();
add_matcher<TSConcatBackward>();
add_matcher<TSSplitBackward>();
add_matcher<TSDataMovementBackward>();
add_matcher<TSReductionBackward>();
add_matcher<TSSqueezeBackward>();
add_matcher<TSUnsqueezeBackward>();
add_matcher<TSInterpolateBackward>();
add_matcher<TSSliceBackward>();
add_matcher<TSGatherBackward>();
add_matcher<TSFuse>();
ADD_MATCHER(this, TSUnaryBackward);
ADD_MATCHER(this, TSUnaryBackward);
ADD_MATCHER(this, TSBinaryBackward);
ADD_MATCHER(this, TSConcatBackward);
ADD_MATCHER(this, TSSplitBackward);
ADD_MATCHER(this, TSDataMovementBackward);
ADD_MATCHER(this, TSReductionBackward);
ADD_MATCHER(this, TSSqueezeBackward);
ADD_MATCHER(this, TSUnsqueezeBackward);
ADD_MATCHER(this, TSInterpolateBackward);
ADD_MATCHER(this, TSSliceBackward);
ADD_MATCHER(this, TSGatherBackward);
ADD_MATCHER(this, TSCumSumBackward);
ADD_MATCHER(this, TSTileBackward);
ADD_MATCHER(this, TSFuse);
}

bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/transpose_sinking/ts_tile.hpp"

#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"

using namespace ov;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;

#undef TILE_REPEATS_INPUT_IDX
#define TILE_REPEATS_INPUT_IDX 1

TSTileForward::TSTileForward() {
MATCHER_SCOPE(TSTileForward);

create_pattern<ov::op::v0::Tile>(true, {0});

auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
if (transformation_callback(main_node)) {
return false;
}

bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0});
if (!res)
return res;

const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
const auto& new_repeats =
ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats);
main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats);

default_outputs_update(main_node, transpose_info);
return true;
};
transpose_sinking(matcher_name, sinking_transformation);
}

TSTileBackward::TSTileBackward() {
MATCHER_SCOPE(TSTileBackward);
auto main_node_label = wrap_type<ov::op::v0::Tile>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && CheckTransposeConsumers(output);
});

auto transpose_const_label = wrap_type<ov::op::v0::Constant>();

auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label},
[](const Output<Node>& output) -> bool {
return has_static_rank()(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const =
as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();

if (transformation_callback(main_node)) {
return false;
}

for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {
register_new_node(new_node);
}

RemoveTransposeConsumers(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto new_repeats =
ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats);
main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats);

main_node->validate_and_infer_types();
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
Loading

0 comments on commit 4d9f2f3

Please sign in to comment.