Skip to content

Commit

Permalink
De-Reshape MatMul
Browse files Browse the repository at this point in the history
  • Loading branch information
jane-intel committed Oct 10, 2023
1 parent 0dcde7f commit 9487cc3
Show file tree
Hide file tree
Showing 7 changed files with 904 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <openvino/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>

namespace ov {
namespace pass {
class TRANSFORMATIONS_API DeReshapeMatMul;
} // namespace pass
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding MatMul.
* It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind of way.
*
* Example:
* Before:
* [A,B,C,D] -> Reshape -> [A*B,C,D]
* MatMul [A*B,C,E] -> Reshape -> [A,B,C,E]
* [A,B,D,E] -> Reshape -> [A*B,D,E]
*
* After:
* [A,B,C,D] ->
* MatMul -> [A,B,C,E]
* [A,B,D,E] ->
*
* Transformation allows slightly different variations of the pattern on inputs of MatMul.
* - Simplest pattern contains only Reshape operation on MatMul input:
* Reshape -> MatMul
*
* - The next acceptable variation is Concat of two inputs on MatMul input:
* Reshape -[-> Concat -]-> MatMul
* This variation would be transformed with realignment of the other input of Concat and the other outputs of
* Concat with the help of Reshape operations
*
* - The most complex variation on the MatMul input pattern is with Binary Elementwise Operation with scalar second
* input: Reshape -[-> Concat -]-[-> BEA (scalar) -]-> MatMul
*
* Additionally, transformation supports variation of the pattern on output of MatMul. It allows for
* Binary Elementwise Arithmetic operation without second input scalar restriction.
* MatMul -[-> BEA -]-> Reshape
* this pattern variation is only applicable for the case when input reshapes are 4D -> 3D and output reshape is 3D ->
* 4D. Additionally, shape labels on output of MatMul should be equal to the input shape labels of the last Reshape,
* meaning that this Binary Elementwise Arithmetic doesn't perform any broadcasting of input coming from MatMul -- only
* other input may be broadcasted to the MatMul input of this BEA. This effect (equality of MatMul output shape labels
* and output shape of BEA) is being handled by LabelResolvingThroughSelect transformation in the particular models that
* this variation targets.
*
* Full pattern this transformation searches for:
* -> Reshape -[-> Concat -]-[-> BEA (scalar) -]->
* MatMul -[-> BEA -]-> Reshape ->
* -> Reshape -[-> Concat -]-[-> BEA (scalar) -]->
*
* NOTE: input branches could be (and in observed model cases are) asymmetrical, meaning that the presence of Concat
* on one input of MatMul doesn't require the other input to also have Concat
*/
class ov::pass::DeReshapeMatMul : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("DeReshapeMatMul", "0");
DeReshapeMatMul();
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace ov {
namespace pass {
class TRANSFORMATIONS_API SymbolicOptimizations;
class TRANSFORMATIONS_API SymbolicPropagation;
class TRANSFORMATIONS_API LabelResolvingThroughSelect;
} // namespace pass
} // namespace ov

Expand Down Expand Up @@ -48,3 +49,10 @@ class ov::pass::SymbolicPropagation : public ov::pass::ModelPass {
private:
std::shared_ptr<ov::TableOfEquivalence> m_te;
};

// TODO: add description and order
class ov::pass::LabelResolvingThroughSelect : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("LabelResolvingThroughSelect", "0");
LabelResolvingThroughSelect();
};
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ TRANSFORMATIONS_API bool get_labels(const ov::Output<ov::Node>& output, ov::Tens
///
/// \return true if labels are unique and equal between lhs and rhs else false
TRANSFORMATIONS_API bool are_unique_and_equal_labels(const ov::TensorLabel& lhs, const ov::TensorLabel& rhs);

/// \brief Compares dimensions: if dimensions are static compares values of dimensions, if dimensions are dynamic
/// compares their respective labels using TableOfEquivalence
///
/// \param lhs Dimension object to compare
/// \param rhs Dimension object to compare
///
/// \return true if static dimensions are equal and dynamic dimensions have equal labels else false
TRANSFORMATIONS_API bool dims_are_equal(const ov::Dimension& lhs, const ov::Dimension& rhs);

} // namespace util
} // namespace symbol
} // namespace ov
Loading

0 comments on commit 9487cc3

Please sign in to comment.