From f84484dd8a71b83be8fb2454ee7d2d65d7bf7324 Mon Sep 17 00:00:00 2001 From: Evgeniia Nugmanova Date: Tue, 10 Oct 2023 15:30:51 +0400 Subject: [PATCH] De-Reshape MatMul --- .../dereshape_matmul.hpp | 67 +++ .../symbolic_optimizations.hpp | 8 + .../symbolic_transformations/utils.hpp | 10 + .../dereshape_matmul.cpp | 335 +++++++++++++++ .../symbolic_optimizations.cpp | 69 ++- .../symbolic_transformations/utils.cpp | 17 + .../dereshape_matmul.cpp | 399 ++++++++++++++++++ 7 files changed, 904 insertions(+), 1 deletion(-) create mode 100644 src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp create mode 100644 src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp create mode 100644 src/common/transformations/tests/symbolic_transformations/dereshape_matmul.cpp diff --git a/src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp b/src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp new file mode 100644 index 00000000000000..90df7a04206466 --- /dev/null +++ b/src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { +namespace pass { +class TRANSFORMATIONS_API DeReshapeMatMul; +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding MatMul. + * It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind of way. + * + * Example: + * Before: + * [A,B,C,D] -> Reshape -> [A*B,C,D] + * MatMul [A*B,C,E] -> Reshape -> [A,B,C,E] + * [A,B,D,E] -> Reshape -> [A*B,D,E] + * + * After: + * [A,B,C,D] -> + * MatMul -> [A,B,C,E] + * [A,B,D,E] -> + * + * Transformation allows slightly different variations of the pattern on inputs of MatMul. + * - Simplest pattern contains only Reshape operation on MatMul input: + * Reshape -> MatMul + * + * - The next acceptable variation is Concat of two inputs on MatMul input: + * Reshape -[-> Concat -]-> MatMul + * This variation would be transformed with realignment of the other input of Concat and the other outputs of + * Concat with the help of Reshape operations + * + * - The most complex variation on the MatMul input pattern is with Binary Elementwise Operation with scalar second + * input: Reshape -[-> Concat -]-[-> BEA (scalar) -]-> MatMul + * + * Additionally, transformation supports variation of the pattern on output of MatMul. It allows for + * Binary Elementwise Arithmetic operation without second input scalar restriction. + * MatMul -[-> BEA -]-> Reshape + * this pattern variation is only applicable for the case when input reshapes are 4D -> 3D and output reshape is 3D -> + * 4D. Additionally, shape labels on output of MatMul should be equal to the input shape labels of the last Reshape, + * meaning that this Binary Elementwise Arithmetic doesn't perform any broadcasting of input coming from MatMul -- only + * other input may be broadcasted to the MatMul input of this BEA. This effect (equality of MatMul output shape labels + * and output shape of BEA) is being handled by LabelResolvingThroughSelect transformation in the particular models that + * this variation targets. + * + * Full pattern this transformation searches for: + * -> Reshape -[-> Concat -]-[-> BEA (scalar) -]-> + * MatMul -[-> BEA -]-> Reshape -> + * -> Reshape -[-> Concat -]-[-> BEA (scalar) -]-> + * + * NOTE: input branches could be (and in observed model cases are) asymmetrical, meaning that the presence of Concat + * on one input of MatMul doesn't require the other input to also have Concat + */ +class ov::pass::DeReshapeMatMul : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("DeReshapeMatMul", "0"); + DeReshapeMatMul(); +}; diff --git a/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp b/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp index 1cf3cf9577dc78..71e234cfeabd29 100644 --- a/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp +++ b/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp @@ -14,6 +14,7 @@ namespace ov { namespace pass { class TRANSFORMATIONS_API SymbolicOptimizations; class TRANSFORMATIONS_API SymbolicPropagation; +class TRANSFORMATIONS_API LabelResolvingThroughSelect; } // namespace pass } // namespace ov @@ -48,3 +49,10 @@ class ov::pass::SymbolicPropagation : public ov::pass::ModelPass { private: std::shared_ptr m_te; }; + +// TODO: add description and order +class ov::pass::LabelResolvingThroughSelect : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("LabelResolvingThroughSelect", "0"); + LabelResolvingThroughSelect(); +}; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/symbolic_transformations/utils.hpp b/src/common/transformations/include/transformations/symbolic_transformations/utils.hpp index 2f3d84dfe825ff..8d6e927e25a995 100644 --- a/src/common/transformations/include/transformations/symbolic_transformations/utils.hpp +++ b/src/common/transformations/include/transformations/symbolic_transformations/utils.hpp @@ -38,6 +38,16 @@ TRANSFORMATIONS_API bool get_labels(const ov::Output& output, ov::Tens /// /// \return true if labels are unique and equal between lhs and rhs else false TRANSFORMATIONS_API bool are_unique_and_equal_labels(const ov::TensorLabel& lhs, const ov::TensorLabel& rhs); + +/// \brief Compares dimensions: if dimensions are static compares values of dimensions, if dimensions are dynamic +/// compares their respective labels using TableOfEquivalence +/// +/// \param lhs Dimension object to compare +/// \param rhs Dimension object to compare +/// +/// \return true if static dimensions are equal and dynamic dimensions have equal labels else false +TRANSFORMATIONS_API bool dims_are_equal(const ov::Dimension& lhs, const ov::Dimension& rhs); + } // namespace util } // namespace symbol } // namespace ov diff --git a/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp b/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp new file mode 100644 index 00000000000000..be866fbc7b1c4f --- /dev/null +++ b/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp @@ -0,0 +1,335 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/symbolic_transformations/dereshape_matmul.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "itt.hpp" +#include "openvino/core/validation_util.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov::symbol::util; + +namespace { +bool concat_predicate(ov::Output output) { + auto output_pshape = output.get_partial_shape(); + if (output_pshape.rank().is_dynamic() || output_pshape.size() <= 2) + return false; + const auto& concat = ov::as_type_ptr(output.get_node_shared_ptr()); + if (!concat) + return false; + return concat->get_concatenation_axis() >= output_pshape.rank().get_length() - 2; +} + +bool last_two_dims_are_equal(const ov::PartialShape& lhs, const ov::PartialShape& rhs) { + if (lhs.rank().is_dynamic() || lhs.size() < 2) + return false; + if (rhs.rank().is_dynamic() || rhs.size() < 2) + return false; + for (size_t i = 2; i > 0; --i) + if (!dims_are_equal(lhs[lhs.size() - i], rhs[rhs.size() - i])) + return false; + return true; +} + +bool reshape_keeps_last_two_dims(const std::shared_ptr& op) { + return last_two_dims_are_equal(op->get_input_partial_shape(0), op->get_output_partial_shape(0)); +} + +bool batches_are_equal(const ov::PartialShape& lhs, const ov::PartialShape& rhs, bool one_dim_can_differ = false) { + if (lhs.rank().is_dynamic() || rhs.rank().is_dynamic() || lhs.size() != rhs.size()) + return false; + size_t num_dims_differ = 0; + for (size_t i = 0; i < lhs.size() - 2; ++i) + num_dims_differ += !dims_are_equal(lhs[i], rhs[i]); + return num_dims_differ <= one_dim_can_differ; +} + +bool batches_are_equal(const std::shared_ptr& op_0, const std::shared_ptr& op_1) { + auto input_0 = op_0->get_input_partial_shape(0); + auto input_1 = op_1->get_input_partial_shape(0); + auto output_0 = op_0->get_output_partial_shape(0); + auto output_1 = op_1->get_output_partial_shape(0); + return batches_are_equal(input_0, input_1, true) && batches_are_equal(output_0, output_1); +} + +ov::Output get_shape_from_sources(const ov::Output& batch_dims_source, + const ov::Output& non_batch_dims_source, + const std::vector>& copy_rt_info_from) { + ov::NodeVector dims; + size_t num_batch_dims = batch_dims_source.get_partial_shape().size() - 2; + std::vector non_constant_ids; + for (size_t i = 0; i < num_batch_dims; ++i) { + auto node = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(batch_dims_source, + {i}, + copy_rt_info_from); + OPENVINO_SUPPRESS_DEPRECATED_START + if (auto constant = ov::get_constant_from_source(node)) { + OPENVINO_SUPPRESS_DEPRECATED_END + node = constant; + } else { + non_constant_ids.push_back(i); + } + dims.push_back(node); + } + if (non_constant_ids.size() == 1) { + dims[non_constant_ids[0]] = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + } + + size_t non_batch_dims_start = non_batch_dims_source.get_partial_shape().size() - 2; + for (size_t i = non_batch_dims_start; i < non_batch_dims_start + 2; ++i) { + auto node = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(non_batch_dims_source, + {i}, + copy_rt_info_from); + OPENVINO_SUPPRESS_DEPRECATED_START + if (auto constant = ov::get_constant_from_source(node)) { + OPENVINO_SUPPRESS_DEPRECATED_END + node = constant; + } + dims.push_back(node); + } + + for (size_t curr_i = 1; curr_i < dims.size(); ++curr_i) { + const auto& curr_node = dims[curr_i]; + if (bool current_node_is_constant = ov::op::util::is_constant(curr_node)) { + size_t prev_i = curr_i - 1; + const auto& prev_node = dims[prev_i]; + if (bool previous_node_exists_and_is_constant = prev_node && ov::op::util::is_constant(prev_node)) { + dims[curr_i] = ov::op::util::make_try_fold(ov::NodeVector{prev_node, curr_node}, 0); + dims[prev_i] = nullptr; + } + } + } + dims.erase(std::remove_if(dims.begin(), + dims.end(), + [](const std::shared_ptr& node) { + return node == nullptr; + }), + dims.end()); + auto target_shape = ov::op::util::make_try_fold(dims, 0); + ov::copy_runtime_info(copy_rt_info_from, target_shape); + return target_shape->output(0); +} + +void pull_reshape_through_optional_concat_and_bea(const ov::pass::pattern::PatternValueMap& vm, + std::shared_ptr concat_label, + std::shared_ptr bea_label, + ov::Output reshape_output, + ov::Input matmul_input, + std::vector& nodes_for_revalidation) { + // Reshape -- [Concat] -- [BEA with scalar] -- > MatMul + auto original_reshape = reshape_output.get_node_shared_ptr(); + if (vm.count(concat_label)) { + auto concat_node = ov::as_type_ptr(vm.at(concat_label).get_node_shared_ptr()); + OPENVINO_ASSERT(concat_node != nullptr, + "DeReshapeMatMul transformation matched operation which should be Concat -- but it is not"); + auto rank = concat_node->get_output_partial_shape(0).rank().get_length(); + auto axis = (concat_node->get_concatenation_axis() == (rank - 1)) ? -1 : -2; + + auto idx_of_reshape_input = reshape_output == concat_node->input_value(0) ? 0 : 1; + auto idx_of_non_reshape_input = static_cast(!idx_of_reshape_input); + + auto target_shape_of_input = get_shape_from_sources(original_reshape->input_value(0), + concat_node->input_value(idx_of_non_reshape_input), + {original_reshape}); + + auto input_reshape = original_reshape->clone_with_new_inputs( + {concat_node->input_value(idx_of_non_reshape_input), target_shape_of_input}); + ov::copy_runtime_info(original_reshape, input_reshape); + + ov::replace_output_update_name(reshape_output, original_reshape->input_value(0)); + + ov::OutputVector new_concat_inputs(2); + new_concat_inputs[idx_of_reshape_input] = concat_node->input_value(idx_of_reshape_input); + new_concat_inputs[idx_of_non_reshape_input] = input_reshape->output(0); + + auto new_concat = std::make_shared(new_concat_inputs, axis); + ov::copy_runtime_info({concat_node, original_reshape}, new_concat); + + auto target_shape_of_output = + get_shape_from_sources(input_reshape->input_value(0), new_concat->output(0), {original_reshape}); + auto output_reshape = original_reshape->clone_with_new_inputs({new_concat->output(0), target_shape_of_output}); + ov::copy_runtime_info(original_reshape, output_reshape); + + if (vm.count(bea_label)) { + auto bea_node = vm.at(bea_label).get_node_shared_ptr(); + auto idx_of_non_scalar_data = bea_node->input_value(0) == vm.at(concat_label) ? 0 : 1; + bea_node->input(idx_of_non_scalar_data).replace_source_output(new_concat); + nodes_for_revalidation.insert(nodes_for_revalidation.begin(), bea_node.get()); + } else { + matmul_input.replace_source_output(new_concat); + } + ov::replace_output_update_name(concat_node->output(0), output_reshape->output(0)); + } else { + // no Concat and it doesn't matter if BEA is present -- just delete reshape + ov::replace_output_update_name(reshape_output, original_reshape->input_value(0)); + } +} +} // namespace + +#define IN_RESHAPE \ + pattern::wrap_type(pattern::op::as_value_predicate([](std::shared_ptr n) -> bool { \ + return pattern::consumers_count(1)(n->output(0)) && reshape_keeps_last_two_dims(n); \ + })); + +#define SCALAR_INPUT \ + pattern::any_input([](ov::Output out) { \ + return out.get_partial_shape().is_static() && ov::shape_size(out.get_shape()) == 1; \ + }); + +ov::pass::DeReshapeMatMul::DeReshapeMatMul() { + MATCHER_SCOPE(DeReshapeMatMul); + // BEGIN: symmetrical patterns for MatMul inputs + + // lhs of MatMul + auto lhs_reshape = IN_RESHAPE; + + auto lhs_concat_0 = pattern::wrap_type({pattern::any_input(), lhs_reshape}, concat_predicate); + auto lhs_concat_1 = pattern::wrap_type({lhs_reshape, pattern::any_input()}, concat_predicate); + auto lhs_concat = std::make_shared(OutputVector{lhs_concat_0, lhs_concat_1}); + + auto lhs_reshape_or_concat = std::make_shared(OutputVector{lhs_reshape, lhs_concat}); + + auto lhs_bea_scalar = SCALAR_INPUT; + auto lhs_bea = pattern::wrap_type({lhs_reshape_or_concat, lhs_bea_scalar}, + pattern::consumers_count(1)); + + auto lhs_bea_or_concat = std::make_shared(OutputVector{lhs_reshape_or_concat, lhs_bea}); + + // rhs of MatMul + auto rhs_reshape = IN_RESHAPE; + + auto rhs_concat_0 = pattern::wrap_type({pattern::any_input(), rhs_reshape}, concat_predicate); + auto rhs_concat_1 = pattern::wrap_type({rhs_reshape, pattern::any_input()}, concat_predicate); + auto rhs_concat = std::make_shared(OutputVector{rhs_concat_0, rhs_concat_1}); + + auto rhs_reshape_or_concat = std::make_shared(OutputVector{rhs_reshape, rhs_concat}); + + auto rhs_bea_scalar = SCALAR_INPUT; + auto rhs_bea = pattern::wrap_type({rhs_reshape_or_concat, rhs_bea_scalar}, + pattern::consumers_count(1)); + + auto rhs_bea_or_concat = std::make_shared(OutputVector{rhs_reshape_or_concat, rhs_bea}); + // END: symmetrical patterns for MatMul inputs + + auto matmul = + pattern::wrap_type({lhs_bea_or_concat, rhs_bea_or_concat}, pattern::consumers_count(1)); + + auto add = pattern::wrap_type( + OutputVector{matmul, pattern::any_input()}, + [](ov::Output out) -> bool { + if (!pattern::consumers_count(1)(out)) + return false; + auto input_0_pshape = out.get_node_shared_ptr()->get_input_partial_shape(0); + auto input_1_pshape = out.get_node_shared_ptr()->get_input_partial_shape(1); + auto output_pshape = out.get_partial_shape(); + ov::TensorLabel output_labels, input_0_labels, input_1_labels; + if (get_labels(input_0_pshape, input_0_labels) && get_labels(input_1_pshape, input_1_labels) && + get_labels(output_pshape, output_labels)) { + if (input_0_pshape.size() != 3 || input_1_pshape.size() != 3 || output_pshape.size() != 3) + return false; + return are_unique_and_equal_labels(input_0_labels, output_labels) || + are_unique_and_equal_labels(input_1_labels, output_labels); + } else { + return false; + } + }); + + auto matmul_or_add = std::make_shared(OutputVector{matmul, add}); + auto final_reshape = + pattern::wrap_type({matmul_or_add, pattern::any_input()}, + pattern::op::as_value_predicate([](std::shared_ptr n) -> bool { + return reshape_keeps_last_two_dims(n); + })); + + ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { + const auto& pm = m.get_pattern_map(); + const auto& vm = m.get_pattern_value_map(); + std::vector nodes_for_revalidation{pm.at(matmul).get()}; + // reshapes check: BEGIN + // reshape_keeps_last_two_dims checks were already applied for all Reshapes in the pattern predicates + auto in_reshape_0 = pm.at(lhs_reshape); + auto in_reshape_1 = pm.at(rhs_reshape); + auto out_reshape = pm.at(final_reshape); + if (!batches_are_equal(in_reshape_0, in_reshape_1) || + !batches_are_equal(in_reshape_0->get_output_partial_shape(0), out_reshape->get_input_partial_shape(0)) || + !batches_are_equal(in_reshape_0->get_input_partial_shape(0), + out_reshape->get_output_partial_shape(0), + true)) { + return false; + } + // reshapes check: END + + if (vm.count(add)) { + const auto& in_reshape_0_in_pshape = in_reshape_0->get_input_partial_shape(0); + if (in_reshape_0_in_pshape.size() != 4 || in_reshape_0_in_pshape[1].is_dynamic()) + return false; + // we only allow MatMul -> Add pattern to be optimized in case of 4d -> 3d -> 4d DeReshaping + } + + // preventing wrong matches + if (vm.count(lhs_concat) && !ov::as_type_ptr(pm.at(lhs_concat))) + return false; + if (vm.count(rhs_concat) && !ov::as_type_ptr(pm.at(rhs_concat))) + return false; + + pull_reshape_through_optional_concat_and_bea(vm, + lhs_concat, + lhs_bea, + in_reshape_0, + pm.at(matmul)->input(0), + nodes_for_revalidation); + pull_reshape_through_optional_concat_and_bea(vm, + rhs_concat, + rhs_bea, + in_reshape_1, + pm.at(matmul)->input(1), + nodes_for_revalidation); + + for (auto& node : nodes_for_revalidation) + node->validate_and_infer_types(); + + if (vm.count(add)) { + // TODO: make sure other elements of the shape are equal -- only those which aren't equal should be handled + auto add_node = pm.at(add); + size_t matmul_port = (add_node->input_value(0) == vm.at(matmul) ? 0 : 1); + size_t non_matmul_port = static_cast(!matmul_port); + + auto first_batch_dim = + ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(add_node->input_value(matmul_port), + {0}, + {in_reshape_0, in_reshape_1}); + auto divisor = + ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(in_reshape_0->input_value(0), + {1}, + {in_reshape_0, in_reshape_1}); + first_batch_dim = std::make_shared(first_batch_dim, divisor, true); + auto minus_one = ov::op::v0::Constant::create(element::i64, {1}, {-1}); + auto non_batch_dims = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source( + add_node->input_value(non_matmul_port), + {1, 2}, + {in_reshape_0, in_reshape_1}); + auto pattern = + std::make_shared(OutputVector{first_batch_dim, minus_one, non_batch_dims}, 0); + auto other_input_reshape = + op::util::make_try_fold(add_node->input_value(non_matmul_port), pattern, true); + add_node->input(non_matmul_port).replace_source_output(other_input_reshape->output(0)); + ov::copy_runtime_info({in_reshape_0, in_reshape_1}, {first_batch_dim, minus_one, other_input_reshape}); + add_node->validate_and_infer_types(); + } + ov::replace_output_update_name(out_reshape->output(0), out_reshape->input_value(0)); + return true; + }; + + auto m = std::make_shared(final_reshape, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp index 7451df397ba33c..4892f213f63871 100644 --- a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp +++ b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp @@ -6,16 +6,26 @@ #include #include +#include #include #include +#include +#include #include #include #include +#include #include +#include #include #include +#include #include "itt.hpp" +#include "openvino/pass/pattern/op/or.hpp" + +using namespace ov::pass; +using namespace ov::symbol::util; namespace { void symbolic_set_up_for_shape(ov::DimensionTracker& dt, ov::PartialShape& shape) { @@ -116,6 +126,60 @@ bool ov::pass::SymbolicPropagation::run_on_model(const std::shared_ptr(); + auto input_reshape = pattern::wrap_type({add, pattern::any_input()}); + + auto select_then = pattern::wrap_type({pattern::any_input(), input_reshape, pattern::any_input()}); + auto select_else = pattern::wrap_type({pattern::any_input(), pattern::any_input(), input_reshape}); + auto select = std::make_shared(OutputVector{select_then, select_else}); + + auto softmax = pattern::wrap_type({select}); + auto reshape = pattern::wrap_type({softmax, pattern::any_input()}); + + ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { + const auto& value_map = m.get_pattern_value_map(); + ov::TensorLabel reshape_labels, add_0_labels, add_1_labels; + if (!get_labels(value_map.at(reshape).get_partial_shape(), reshape_labels)) + return false; + auto add_node = value_map.at(add).get_node_shared_ptr(); + auto add_0_pshape = add_node->input_value(0).get_partial_shape(); + auto add_1_pshape = add_node->input_value(1).get_partial_shape(); + if (!get_labels(add_0_pshape, add_0_labels) && !get_labels(add_1_pshape, add_1_labels)) + return false; + + if (are_unique_and_equal_labels(reshape_labels, add_0_labels)) { + // we detected that no broadcasting was done during binary elementwise and select, propagating labels + // through + add_node->set_output_type(0, add_node->get_output_element_type(0), add_0_pshape); + } else if (are_unique_and_equal_labels(reshape_labels, add_1_labels)) { + // we detected that no broadcasting was done during binary elementwise and select, propagating labels + // through + add_node->set_output_type(0, add_node->get_output_element_type(0), add_1_pshape); + } else { + return false; + } + + std::shared_ptr select_node = nullptr; + if (value_map.count(select_then)) + select_node = value_map.at(select_then).get_node_shared_ptr(); + if (value_map.count(select_else)) + select_node = value_map.at(select_else).get_node_shared_ptr(); + if (select_node == nullptr) + return false; + + auto select_output = select_node->output(0); + const auto& reshape_pshape = value_map.at(input_reshape).get_partial_shape(); + select_node->set_output_type(0, select_node->get_output_element_type(0), reshape_pshape); + value_map.at(softmax).get_node_shared_ptr()->validate_and_infer_types(); + return true; + }; + + auto m = std::make_shared(reshape, matcher_name); + register_matcher(m, matcher_pass_callback); +} + ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) { m_manager = std::make_shared(); m_manager->set_per_pass_validation(false); @@ -134,7 +198,10 @@ ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) { // transformations which use labels for optimizations REGISTER_SYMBOLIC(ApplyTableOfEquivalence) if (full_run) { - REGISTER_SYMBOLIC(OptimizeLabelsUsedAsValues) // reduce shape sub-graphs + REGISTER_SYMBOLIC(OptimizeLabelsUsedAsValues) // reduce shape sub-graphs + REGISTER_SYMBOLIC(LabelResolvingThroughSelect) // figures out that broadcasting didn't happen through Select op + REGISTER_SYMBOLIC(DeReshapeMatMul) + REGISTER_SYMBOLIC(SimplifyShapeOfSubGraph) } } diff --git a/src/common/transformations/src/transformations/symbolic_transformations/utils.cpp b/src/common/transformations/src/transformations/symbolic_transformations/utils.cpp index 3fedc3bd4c85be..32b572908c5fa5 100644 --- a/src/common/transformations/src/transformations/symbolic_transformations/utils.cpp +++ b/src/common/transformations/src/transformations/symbolic_transformations/utils.cpp @@ -32,3 +32,20 @@ bool ov::symbol::util::are_unique_and_equal_labels(const ov::TensorLabel& lhs, c return false; return true; } + +bool dims_are_equal(const ov::Dimension& lhs, const ov::Dimension& rhs) { + bool labels_exist_and_equal = false; + + auto lhs_label = ov::DimensionTracker::get_label(lhs); + auto rhs_label = ov::DimensionTracker::get_label(rhs); + auto table_l = ov::DimensionTracker::get_table_of_equivalence(lhs); + auto table_r = ov::DimensionTracker::get_table_of_equivalence(rhs); + if (table_l) + labels_exist_and_equal = lhs_label != ov::no_label && table_l->are_equal(lhs, rhs); + else if (table_r) + labels_exist_and_equal = lhs_label != ov::no_label && table_r->are_equal(lhs, rhs); + else + labels_exist_and_equal = lhs_label != ov::no_label && lhs_label == rhs_label; + bool dims_are_static_and_equal = lhs.is_static() && lhs == rhs; + return labels_exist_and_equal || dims_are_static_and_equal; +} diff --git a/src/common/transformations/tests/symbolic_transformations/dereshape_matmul.cpp b/src/common/transformations/tests/symbolic_transformations/dereshape_matmul.cpp new file mode 100644 index 00000000000000..6e87090e488316 --- /dev/null +++ b/src/common/transformations/tests/symbolic_transformations/dereshape_matmul.cpp @@ -0,0 +1,399 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/symbolic_transformations/dereshape_matmul.hpp" + +#include + +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "openvino/core/dimension_tracker.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov; +using namespace ov::op; +using namespace std; + +namespace { +/* Helps to organize dimension representation in the following tests: + * 1. Creates requested amount of dimensions + * 2. Labels them automatically + * 3. Creates value representation of the dimension via creating Parameter->Shape->Gather subgraph + * 4. Gives access to dimension and its value representation via operator[] + * 5. Gives access to utility Parameter via get_parameter -- only used for ov::Model creation in tests + * */ +class DimensionTestHelper { +public: + struct DimensionWithOutput { + Dimension dim; + Output source; + }; + + explicit DimensionTestHelper(const size_t& num_dims) { + auto te = make_shared(); + auto dt = ov::DimensionTracker(te); + auto dimensions = PartialShape::dynamic(Rank(num_dims)); + dt.set_up_for_tracking(dimensions); + parameter = make_shared(element::f32, dimensions); + for (size_t i = 0; i < num_dims; ++i) + m_map[i] = {dimensions[i], op::util::node_to_get_shape_value_of_indices_from_shape_source(parameter, {i})}; + } + + DimensionWithOutput operator[](size_t idx) const { + return m_map.at(idx); + } + + ov::PartialShape make_shape(const vector& dim_indices) const { + auto shape = PartialShape::dynamic(Rank(dim_indices.size())); + for (size_t i = 0; i < dim_indices.size(); ++i) + shape[i] = m_map.at(dim_indices[i]).dim; + return shape; + } + + shared_ptr make_reshape(const Output& source, const vector& dims_indices) const { + OutputVector sources(dims_indices.size()); + for (size_t i = 0; i < dims_indices.size(); ++i) + sources[i] = m_map.at(dims_indices[i]).source; + auto concat = make_shared(sources, 0); + return make_shared(source, concat, false); + } + + std::shared_ptr get_parameter() const { + return parameter; + } + +private: + std::shared_ptr parameter; + std::map m_map; +}; + +size_t max_element(const vector>& vectors) { + size_t current_max = 0; + for (const auto& vector : vectors) + current_max = max(current_max, *std::max_element(vector.begin(), vector.end())); + return current_max; +} + +shared_ptr reshape(const Output& source, + const vector& dims_indices, + const DimensionTestHelper& helper) { + OutputVector sources(dims_indices.size()); + for (size_t i = 0; i < dims_indices.size(); ++i) + sources[i] = helper[dims_indices[i]].source; + auto concat = make_shared(sources, 0); + return make_shared(source, concat, false); +} + +ov::Output get_shape_from_sources(const ov::Output& batch_dims_source, + const ov::Output& non_batch_dims_source) { + auto batch_indices = std::vector(batch_dims_source.get_partial_shape().size() - 2); + std::iota(batch_indices.begin(), batch_indices.end(), 0); + auto batch_dims = + ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(batch_dims_source, batch_indices); + auto non_batch_indices = std::vector(2); + std::iota(non_batch_indices.begin(), non_batch_indices.end(), non_batch_dims_source.get_partial_shape().size() - 2); + auto non_batch_dims = + ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(non_batch_dims_source, non_batch_indices); + auto target_shape = + ov::op::util::make_try_fold(ov::OutputVector{batch_dims, non_batch_dims}, 0); + return target_shape->output(0); +} + +PartialShape make_concat_input_pshape(const DimensionTestHelper& dims, const vector& dims_indices) { + auto another_pshape = dims.make_shape(dims_indices); + size_t rank = dims_indices.size(); + // To reduce test graph we avoid changing Concat axis dimension with this Concat + another_pshape[rank - 1] = Dimension(0); + return another_pshape; +} + +static std::ostream& operator<<(std::ostream& os, const vector& vals) { + bool first = true; + for (const auto& val : vals) { + if (!first) + os << "_"; + first = false; + os << val; + } + return os; +} +} // namespace + +using DeReshapeMatMulParameters = + tuple, vector, vector, vector, vector>, + size_t, + size_t, + size_t>; + +class DeReshapeMatMulTest : public TransformationTestsF, public testing::WithParamInterface { +public: + void SetUp() override { + TransformationTestsF::SetUp(); + const auto& params = std::get<0>(GetParam()); + + const auto& lhs_shape_idx = std::get<0>(params); + const auto& lhs_reshape_idx = std::get<1>(params); + const auto& rhs_shape_idx = std::get<2>(params); + const auto& rhs_reshape_idx = std::get<3>(params); + const auto& out_reshape_idx = std::get<4>(params); + + // 0 - no bea, 1 - lhs, 2 - rhs, 3 - lhs and rhs + const size_t& bea_scalar_mode = std::get<1>(GetParam()); + + // 0 - no concat + // 10 - concat on lhs, reshape on 0 port + // 11 - concat on lhs, reshape on 1 port + // 20 - concat on rhs, reshape on 0 port + // 21 - concat on rhs, reshape on 1 port + // 300 - concat on both sizes, both reshapes on 0 port of concats + // 301 - concat on both sizes, lhs reshape on 0 port, rhs reshape on 1 port + // 310 - concat on both sizes, lhs reshape on 1 port, rhs reshape on 0 port + // 311 - concat on both sizes, both reshapes on 1 port of concats + const size_t& concat_mode = std::get<2>(GetParam()); + + // 0 - no add, 1 - add has matmul on lhs, 2 - add has matmul on rhs + const size_t& final_add_mode = std::get<3>(GetParam()); + + const auto& max_idx = + max_element({lhs_shape_idx, rhs_shape_idx, lhs_reshape_idx, rhs_reshape_idx, out_reshape_idx}); + const DimensionTestHelper dims(max_idx + 1); + + PartialShape lhs_original_pshape = dims.make_shape(lhs_shape_idx); + PartialShape rhs_original_pshape = dims.make_shape(rhs_shape_idx); + + get_model(dims, + lhs_original_pshape, + rhs_original_pshape, + lhs_reshape_idx, + rhs_reshape_idx, + out_reshape_idx, + bea_scalar_mode, + concat_mode, + final_add_mode); + manager.register_pass(); + get_model_ref(dims, + lhs_original_pshape, + rhs_original_pshape, + lhs_reshape_idx, + rhs_reshape_idx, + bea_scalar_mode, + concat_mode, + final_add_mode); + } + + void get_model(const DimensionTestHelper& dims, + const PartialShape& lhs_original_pshape, + const PartialShape& rhs_original_pshape, + const vector& lhs_reshape_idx, + const vector& rhs_reshape_idx, + const vector& out_reshape_idx, + const size_t& bea_scalar_mode, + const size_t& concat_mode, + const size_t& final_add_mode) { + ParameterVector inputs; + OutputVector outputs; + + // LHS input of MatMul + auto lhs_input = make_shared(element::f32, lhs_original_pshape); + auto lhs_output = dims.make_reshape(lhs_input, lhs_reshape_idx); + + if (set{10, 11, 300, 301, 310, 311}.count(concat_mode)) { + const auto& another_pshape = make_concat_input_pshape(dims, lhs_reshape_idx); + const auto& another_input = make_shared(element::f32, another_pshape); + + if (set{10, 300, 301}.count(concat_mode)) { // reshape on 0 port + lhs_output = make_shared(OutputVector{lhs_output, another_input}, -1); + } else if (set{11, 310, 311}.count(concat_mode)) { // reshape on 1 port + lhs_output = make_shared(OutputVector{another_input, lhs_output}, -1); + } else { + ASSERT_TRUE(false) << "Unknown mode of concat: " << concat_mode; + } + inputs.push_back(another_input); + outputs.emplace_back(lhs_output); + } + + if (bea_scalar_mode == 1 || bea_scalar_mode == 3) + lhs_output = make_shared(lhs_output, v0::Constant::create(element::f32, {}, {0.125})); + + // RHS input of MatMul + auto rhs_input = make_shared(element::f32, rhs_original_pshape); + auto rhs_output = dims.make_reshape(rhs_input, rhs_reshape_idx); + + if (set{20, 21, 300, 301, 310, 311}.count(concat_mode)) { + const auto& another_pshape = make_concat_input_pshape(dims, rhs_reshape_idx); + const auto& another_input = make_shared(element::f32, another_pshape); + if (set{20, 300, 310}.count(concat_mode)) { // reshape on 0 port + rhs_output = make_shared(OutputVector{rhs_output, another_input}, -1); + } else if (set{21, 301, 311}.count(concat_mode)) { // reshape on 1 port + rhs_output = make_shared(OutputVector{another_input, rhs_output}, -1); + } else { + ASSERT_TRUE(false) << "Unknown mode of concat: " << concat_mode; + } + inputs.push_back(another_input); + outputs.emplace_back(rhs_output); + } + + if (bea_scalar_mode == 2 || bea_scalar_mode == 3) + rhs_output = make_shared(rhs_output, v0::Constant::create(element::f32, {}, {0.125})); + + Output matmul = make_shared(lhs_output, rhs_output); + + if (final_add_mode == 1) // 1 - add has matmul on lhs + matmul = + make_shared(matmul, v0::Constant::create(element::f32, Shape(lhs_reshape_idx.size(), 1), {1})); + else if (final_add_mode == 2) // 2 - add has matmul on rhs + matmul = + make_shared(v0::Constant::create(element::f32, Shape(lhs_reshape_idx.size(), 1), {1}), matmul); + + auto output_reshape = reshape(matmul, out_reshape_idx, dims); + + inputs.push_back(dims.get_parameter()); + inputs.push_back(lhs_input); + inputs.push_back(rhs_input); + outputs.emplace_back(output_reshape); + + for (auto& output : outputs) + output = std::make_shared(output, v0::Constant::create(element::i32, {1}, {-1}), false); + auto output = make_shared(outputs, 0); + model = make_shared(output, inputs, "Tested model"); + } + + void get_model_ref(const DimensionTestHelper& dims, + const PartialShape& lhs_original_pshape, + const PartialShape& rhs_original_pshape, + const vector& lhs_reshape_idx, + const vector& rhs_reshape_idx, + const size_t& bea_scalar_mode, + const size_t& concat_mode, + const size_t& final_add_mode) { + ParameterVector inputs; + OutputVector outputs; + + // LHS input of MatMul + auto lhs_input = make_shared(element::f32, lhs_original_pshape); + auto lhs_output = lhs_input->output(0); + + if (set{10, 11, 300, 301, 310, 311}.count(concat_mode)) { + const auto& another_pshape = make_concat_input_pshape(dims, lhs_reshape_idx); + const auto& another_input = make_shared(element::f32, another_pshape); + + auto target_shape_of_input = get_shape_from_sources(lhs_output, another_input); + auto input_reshape = make_shared(another_input, target_shape_of_input, false); + + if (set{10, 300, 301}.count(concat_mode)) { // reshape on 0 port + lhs_output = make_shared(OutputVector{lhs_output, input_reshape}, -1); + } else if (set{11, 310, 311}.count(concat_mode)) { // reshape on 1 port + lhs_output = make_shared(OutputVector{input_reshape, lhs_output}, -1); + } else { + ASSERT_TRUE(false) << "Unknown mode of concat: " << concat_mode; + } + + auto target_shape_of_output = get_shape_from_sources(input_reshape->input_value(0), lhs_output); + auto output_reshape = make_shared(lhs_output, target_shape_of_output, false); + + inputs.push_back(another_input); + outputs.emplace_back(output_reshape); + } + + if (bea_scalar_mode == 1 || bea_scalar_mode == 3) + lhs_output = make_shared(lhs_output, v0::Constant::create(element::f32, {}, {0.125})); + + // RHS input of MatMul + auto rhs_input = make_shared(element::f32, rhs_original_pshape); + auto rhs_output = rhs_input->output(0); + + if (set{20, 21, 300, 301, 310, 311}.count(concat_mode)) { + const auto& another_pshape = make_concat_input_pshape(dims, rhs_reshape_idx); + const auto& another_input = make_shared(element::f32, another_pshape); + + auto target_shape_of_input = get_shape_from_sources(rhs_output, another_input); + auto input_reshape = make_shared(another_input, target_shape_of_input, false); + + if (set{20, 300, 310}.count(concat_mode)) { // reshape on 0 port + rhs_output = make_shared(OutputVector{rhs_output, input_reshape}, -1); + } else if (set{21, 301, 311}.count(concat_mode)) { // reshape on 1 port + rhs_output = make_shared(OutputVector{input_reshape, rhs_output}, -1); + } else { + ASSERT_TRUE(false) << "Unknown mode of concat: " << concat_mode; + } + auto target_shape_of_output = get_shape_from_sources(input_reshape->input_value(0), rhs_output); + auto output_reshape = make_shared(rhs_output, target_shape_of_output, false); + + inputs.push_back(another_input); + outputs.emplace_back(output_reshape); + } + + if (bea_scalar_mode == 2 || bea_scalar_mode == 3) + rhs_output = make_shared(rhs_output, v0::Constant::create(element::f32, {}, {0.125})); + + Output matmul = make_shared(lhs_output, rhs_output); + + if (final_add_mode == 1) // 1 - add has matmul on lhs + matmul = + make_shared(matmul, v0::Constant::create(element::f32, Shape(lhs_reshape_idx.size(), 1), {1})); + else if (final_add_mode == 2) // 2 - add has matmul on rhs + matmul = + make_shared(v0::Constant::create(element::f32, Shape(lhs_reshape_idx.size(), 1), {1}), matmul); + + inputs.push_back(dims.get_parameter()); + inputs.push_back(lhs_input); + inputs.push_back(rhs_input); + outputs.emplace_back(matmul); + + for (auto& output : outputs) + output = std::make_shared(output, v0::Constant::create(element::i32, {1}, {-1}), false); + auto output = make_shared(outputs, 0); + + model_ref = make_shared(output, inputs, "Reference model"); + } + + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + vector lhs_input_shape_indices, lhs_reshape_indices; + vector rhs_input_shape_indices, rhs_reshape_indices; + vector output_reshape_indices; + size_t bea_scalar_mode, concat_mode, final_add_mode; + + tuple, vector, vector, vector, vector> tmp; + + std::tie(tmp, bea_scalar_mode, concat_mode, final_add_mode) = obj.param; + std::tie(lhs_input_shape_indices, + lhs_reshape_indices, + rhs_input_shape_indices, + rhs_reshape_indices, + output_reshape_indices) = tmp; + + std::ostringstream result; + result << "l_in_shape_idx=" << lhs_input_shape_indices << "_l_reshape_idx=" << lhs_reshape_indices + << "_r_in_shape_idx=" << rhs_input_shape_indices << "_r_reshape_idx=" << rhs_reshape_indices + << "_out_reshape_idx=" << output_reshape_indices << "_bea_scalar_mode=" << bea_scalar_mode + << "_concat_mode=" << concat_mode << "_final_add_mode=" << final_add_mode; + return result.str(); + } +}; + +const auto shape_test_cases = + vector, vector, vector, vector, vector>>{ + {{0, 1, 2, 3}, {5, 2, 3}, {0, 1, 3, 4}, {5, 3, 4}, {0, 1, 2, 4}}, // 4D -> 3D -> 4D + {{5, 2, 3}, {0, 1, 2, 3}, {5, 3, 4}, {0, 1, 3, 4}, {5, 2, 4}}, // 3D -> 4D -> 3D + {{0, 1, 2, 3, 4}, {0, 6, 3, 4}, {0, 1, 2, 4, 5}, {0, 6, 4, 5}, {0, 1, 2, 3, 5}}, // 5D -> 4D -> 5D + }; + +const auto bea_scalar_modes = vector{0, 1, 2, 3}; +const auto concat_modes = vector{0, 10, 11, 20, 21, 300, 301, 310, 311}; +const auto final_add_modes = vector{0, 1, 2}; + +TEST_P(DeReshapeMatMulTest, DeReshapeTests) {} + +INSTANTIATE_TEST_SUITE_P( + TransformationTestsF, + DeReshapeMatMulTest, + testing::Combine(testing::ValuesIn(shape_test_cases), // lhs_idx, rhs_idx, reshape_idx, reshape_idx, reshape_idx + testing::ValuesIn(bea_scalar_modes), + testing::ValuesIn(concat_modes), + testing::ValuesIn(final_add_modes)), + DeReshapeMatMulTest::getTestCaseName);