Skip to content

Commit

Permalink
[TRANSFORMATIONS][GPU] SDPA Fusion passes (#28042)
Browse files Browse the repository at this point in the history
### Details:
 - Added basic SDPA fusion pass and QK scaling fusion into SDPA

T5 case

---------

Signed-off-by: Vladimir Paramuzov <[email protected]>
  • Loading branch information
vladimir-paramuzov authored Dec 24, 2024
1 parent f62b94f commit b4c81e0
Show file tree
Hide file tree
Showing 9 changed files with 853 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// This pass transforms the following sub-graph to a single Scaled Dot Product Attention operation.
/// Before:
/// ┌───────┐ ┌───────┐ ┌───────┐
/// │ Q │ │ K │ │ V │
/// └───┬───┘ └───┬───┘ └───┬───┘
/// │ │ │
/// │ │ │
/// ┌───┴───┐ ┌─────┴──────┐ │
/// │ MatMul│<──│ Transpose │ │
/// └───┬───┘ | (Optional) │ │
/// │ └────────────┘ │
/// ┌───┴───┐ ┌─────────────┐ │
/// │ Add │<───│AttentionMask│ │
/// └───┬───┘ | (Optional) │ │
/// │ └─────────────┘ │
/// ┌───┴───┐ │
/// │Softmax│ │
/// └───┬───┘ │
/// │ │
/// ┌───┴───┐ │
/// │ MatMul│<─────────────────────┘
/// └───┬───┘
/// ┌───┴───┐
/// │ Output│
/// └───────┘
///
/// After:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘
/// │ │ │ │
/// │ │ │ │
/// ┌───┴────────────┴────────────┴───────────────┴─┐
/// │ ScaledDotProductAttention │
/// └────────────────────┬──────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("SDPAFusion", "0");
SDPAFusion();
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// Merges explicit multiplication by scalar value for Q and K into scale attribute of SDPA op
/// Before:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌─────────────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale |
/// └───┬───┘ └───┬───┘ └───┬───┘ │ (Optional) │ │ (Optional) │
/// │ │ │ └──────┬──────┘ └───────┬─────┘
/// │ │ │ │ |
/// ┌───┴───┐ ┌───┴───┐ │ │ |
/// │ Mul | │ Mul │ | │ |
/// └───┬───┘ └───┬───┘ │ │ │
/// │ │ │ │ │
/// | │ │ │ │
/// ┌───┴────────────┴────────────┴─────────────┴─┐ |
/// │ ScaledDotProductAttention │──────────────────┘
/// └────────────────────┬────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
/// After:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌───────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale |
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ └───┬───┘
/// │ │ │ │ |
/// │ │ │ │ |
/// | │ │ │ |
/// ┌───┴────────────┴────────────┴─────────────┴─┐ |
/// │ ScaledDotProductAttention │───────────┘
/// └────────────────────┬────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
/// Multiply ops for Q and K are eliminated in the following cases:
/// 1. Q_scale and K_scale are constant
/// 2. Q_scale * SDPA_Scale == 1 or K_scale * SDPA_Scale == 1
class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("SDPAScaleFusion", "0");
SDPAScaleFusion();
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"
#include "transformations/common_optimizations/ric_fusion.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
#include "transformations/common_optimizations/sequence_fusion.hpp"
#include "transformations/common_optimizations/shared_ops_optimization.hpp"
Expand Down Expand Up @@ -229,6 +230,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(common_fusions, ConvertTensorIteratorToSequence)
ADD_MATCHER(common_fusions, SplitConcatPairToInterpolateFusion, m_use_shapes)
ADD_MATCHER(common_fusions, ConvolutionToGroupConvolutionFusion)
ADD_MATCHER(common_fusions, SDPAFusion)
if (m_use_shapes) {
ADD_MATCHER(common_fusions, NearestNeighborUpsamplingFusion)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAFusion::SDPAFusion() {
using namespace ov::pass::pattern;
using namespace ov::gen_pattern;

auto q = makePattern(ov::Rank(4));
auto k = makePattern(ov::Rank(4));
auto v = makePattern(ov::Rank(4));
auto mask = makePattern();

auto k_transpose_order = pattern::wrap_type<ov::op::v0::Constant>([](const Output<Node>& node) {
auto axis_order =
std::dynamic_pointer_cast<ov::op::v0::Constant>(node.get_node_shared_ptr())->cast_vector<int64_t>();
return axis_order == std::vector<int64_t>{0, 1, 3, 2};
});

auto k_t = pattern::wrap_type<ov::op::v1::Transpose>({k, k_transpose_order});
auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}});
auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}});
auto qk = qk_nt | qk_nn;
auto optional_add_mask = optional<ov::op::v1::Add>({qk, mask});
auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}});
auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}});

auto valid_qk_shapes = [](const std::shared_ptr<ov::op::v0::MatMul>& qk_matmul) {
auto q_pshape = qk_matmul->get_input_partial_shape(0);
auto k_pshape = qk_matmul->get_input_partial_shape(1);

const size_t q_head_size_idx = 3;
const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2;

return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() &&
k_pshape[k_head_size_idx].is_static() &&
q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length();
};

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
if (transformation_callback(m.get_match_root())) {
return false;
}

auto q_node = pattern_map.at(q);
auto k_node = pattern_map.at(k);
auto v_node = pattern_map.at(v);

if (!valid_qk_shapes(ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(qk).get_node_shared_ptr()))) {
return false;
}

if (pattern_map.at(qk).get_target_inputs().size() > 1 ||
pattern_map.at(softmax).get_target_inputs().size() > 1) {
return false;
}
if (pattern_map.count(optional_add_mask) && (pattern_map.at(optional_add_mask).get_target_inputs().size() > 1 ||
pattern_map.at(mask).get_partial_shape().size() > 4)) {
return false;
}

Output<ov::Node> mask_value;
Output<ov::Node> mask_input;
if (pattern_map.find(optional_add_mask) != pattern_map.end()) {
mask_value = pattern_map.at(mask);
} else {
mask_value = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{0});
}

if (mask_value.get_partial_shape().size() > 4) {
return false;
}

if (mask_value.get_partial_shape().rank() == 0 || mask_value.get_partial_shape().rank() == 4) {
mask_input = mask_value;
} else {
size_t rank_diff = q_node.get_partial_shape().size() - mask_value.get_partial_shape().size();
std::vector<int64_t> axes(rank_diff);
std::iota(axes.begin(), axes.end(), 0);
mask_input = std::make_shared<ov::op::v0::Unsqueeze>(
mask_value,
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank_diff}, axes));
}

std::shared_ptr<ov::Node> scale_node =
ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{1.0f});

std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_node,
k_node,
v_node,
mask_input,
scale_node,
false);

sdpa->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
ov::replace_node(m.get_match_root(), sdpa);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(qkv, "SDPAFusion");
this->register_matcher(m, callback);
}

} // namespace pass
} // namespace ov
Loading

0 comments on commit b4c81e0

Please sign in to comment.