diff --git a/src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp new file mode 100644 index 00000000000000..84383b777604ea --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +/// This pass transforms the following sub-graph to a single Scaled Dot Product Attention operation. +/// Before: +/// ┌───────┐ ┌───────┐ ┌───────┐ +/// │ Q │ │ K │ │ V │ +/// └───┬───┘ └───┬───┘ └───┬───┘ +/// │ │ │ +/// │ │ │ +/// ┌───┴───┐ ┌─────┴──────┐ │ +/// │ MatMul│<──│ Transpose │ │ +/// └───┬───┘ | (Optional) │ │ +/// │ └────────────┘ │ +/// ┌───┴───┐ ┌─────────────┐ │ +/// │ Add │<───│AttentionMask│ │ +/// └───┬───┘ | (Optional) │ │ +/// │ └─────────────┘ │ +/// ┌───┴───┐ │ +/// │Softmax│ │ +/// └───┬───┘ │ +/// │ │ +/// ┌───┴───┐ │ +/// │ MatMul│<─────────────────────┘ +/// └───┬───┘ +/// ┌───┴───┐ +/// │ Output│ +/// └───────┘ +/// +/// After: +/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ +/// │ Q │ │ K │ │ V │ │AttentionMask│ +/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ +/// │ │ │ │ +/// │ │ │ │ +/// ┌───┴────────────┴────────────┴───────────────┴─┐ +/// │ ScaledDotProductAttention │ +/// └────────────────────┬──────────────────────────┘ +/// │ +/// │ +/// ┌────┴────┐ +/// │ Output │ +/// └─────────┘ +class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("SDPAFusion", "0"); + SDPAFusion(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp new file mode 100644 index 00000000000000..cae0363e785f4e --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp @@ -0,0 +1,58 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +/// Merges explicit multiplication by scalar value for Q and K into scale attribute of SDPA op +/// Before: +/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌─────────────┐ +/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale | +/// └───┬───┘ └───┬───┘ └───┬───┘ │ (Optional) │ │ (Optional) │ +/// │ │ │ └──────┬──────┘ └───────┬─────┘ +/// │ │ │ │ | +/// ┌───┴───┐ ┌───┴───┐ │ │ | +/// │ Mul | │ Mul │ | │ | +/// └───┬───┘ └───┬───┘ │ │ │ +/// │ │ │ │ │ +/// | │ │ │ │ +/// ┌───┴────────────┴────────────┴─────────────┴─┐ | +/// │ ScaledDotProductAttention │──────────────────┘ +/// └────────────────────┬────────────────────────┘ +/// │ +/// │ +/// ┌────┴────┐ +/// │ Output │ +/// └─────────┘ +/// After: +/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌───────┐ +/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale | +/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ └───┬───┘ +/// │ │ │ │ | +/// │ │ │ │ | +/// | │ │ │ | +/// ┌───┴────────────┴────────────┴─────────────┴─┐ | +/// │ ScaledDotProductAttention │───────────┘ +/// └────────────────────┬────────────────────────┘ +/// │ +/// │ +/// ┌────┴────┐ +/// │ Output │ +/// └─────────┘ +/// Multiply ops for Q and K are eliminated in the following cases: +/// 1. Q_scale and K_scale are constant +/// 2. Q_scale * SDPA_Scale == 1 or K_scale * SDPA_Scale == 1 +class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("SDPAScaleFusion", "0"); + SDPAScaleFusion(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 185ae84ec83642..23fbf882024bdc 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -65,6 +65,7 @@ #include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp" #include "transformations/common_optimizations/reshape_sequence_fusion.hpp" #include "transformations/common_optimizations/ric_fusion.hpp" +#include "transformations/common_optimizations/sdpa_fusion.hpp" #include "transformations/common_optimizations/select_with_one_value_condition.hpp" #include "transformations/common_optimizations/sequence_fusion.hpp" #include "transformations/common_optimizations/shared_ops_optimization.hpp" @@ -229,6 +230,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr ADD_MATCHER(common_fusions, ConvertTensorIteratorToSequence) ADD_MATCHER(common_fusions, SplitConcatPairToInterpolateFusion, m_use_shapes) ADD_MATCHER(common_fusions, ConvolutionToGroupConvolutionFusion) + ADD_MATCHER(common_fusions, SDPAFusion) if (m_use_shapes) { ADD_MATCHER(common_fusions, NearestNeighborUpsamplingFusion) } diff --git a/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp new file mode 100644 index 00000000000000..fc581580f70001 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp @@ -0,0 +1,127 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/sdpa_fusion.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/core/type.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/pass/pattern/op/optional.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/gen_pattern.hpp" + +namespace ov { +namespace pass { + +SDPAFusion::SDPAFusion() { + using namespace ov::pass::pattern; + using namespace ov::gen_pattern; + + auto q = makePattern(ov::Rank(4)); + auto k = makePattern(ov::Rank(4)); + auto v = makePattern(ov::Rank(4)); + auto mask = makePattern(); + + auto k_transpose_order = pattern::wrap_type([](const Output& node) { + auto axis_order = + std::dynamic_pointer_cast(node.get_node_shared_ptr())->cast_vector(); + return axis_order == std::vector{0, 1, 3, 2}; + }); + + auto k_t = pattern::wrap_type({k, k_transpose_order}); + auto qk_nn = makePattern({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}}); + auto qk_nt = makePattern({q, k}, {{"transpose_a", false}, {"transpose_b", true}}); + auto qk = qk_nt | qk_nn; + auto optional_add_mask = optional({qk, mask}); + auto softmax = makePattern({optional_add_mask}, {{"axis", "-1"}}); + auto qkv = makePattern({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}}); + + auto valid_qk_shapes = [](const std::shared_ptr& qk_matmul) { + auto q_pshape = qk_matmul->get_input_partial_shape(0); + auto k_pshape = qk_matmul->get_input_partial_shape(1); + + const size_t q_head_size_idx = 3; + const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2; + + return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() && + k_pshape[k_head_size_idx].is_static() && + q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length(); + }; + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + if (transformation_callback(m.get_match_root())) { + return false; + } + + auto q_node = pattern_map.at(q); + auto k_node = pattern_map.at(k); + auto v_node = pattern_map.at(v); + + if (!valid_qk_shapes(ov::as_type_ptr(pattern_map.at(qk).get_node_shared_ptr()))) { + return false; + } + + if (pattern_map.at(qk).get_target_inputs().size() > 1 || + pattern_map.at(softmax).get_target_inputs().size() > 1) { + return false; + } + if (pattern_map.count(optional_add_mask) && (pattern_map.at(optional_add_mask).get_target_inputs().size() > 1 || + pattern_map.at(mask).get_partial_shape().size() > 4)) { + return false; + } + + Output mask_value; + Output mask_input; + if (pattern_map.find(optional_add_mask) != pattern_map.end()) { + mask_value = pattern_map.at(mask); + } else { + mask_value = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector{0}); + } + + if (mask_value.get_partial_shape().size() > 4) { + return false; + } + + if (mask_value.get_partial_shape().rank() == 0 || mask_value.get_partial_shape().rank() == 4) { + mask_input = mask_value; + } else { + size_t rank_diff = q_node.get_partial_shape().size() - mask_value.get_partial_shape().size(); + std::vector axes(rank_diff); + std::iota(axes.begin(), axes.end(), 0); + mask_input = std::make_shared( + mask_value, + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank_diff}, axes)); + } + + std::shared_ptr scale_node = + ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector{1.0f}); + + std::shared_ptr sdpa = std::make_shared(q_node, + k_node, + v_node, + mask_input, + scale_node, + false); + + sdpa->set_friendly_name(m.get_match_root()->get_friendly_name()); + ov::copy_runtime_info(m.get_matched_nodes(), sdpa); + ov::replace_node(m.get_match_root(), sdpa); + + return true; + }; + + auto m = std::make_shared(qkv, "SDPAFusion"); + this->register_matcher(m, callback); +} + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp new file mode 100644 index 00000000000000..3d750fe38a868e --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp @@ -0,0 +1,140 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/sdpa_scale_fusion.hpp" + +#include + +#include "openvino/core/node.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/core/type.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/pass/pattern/op/optional.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" +#include "transformations/utils/gen_pattern.hpp" + +namespace ov { +namespace pass { + +SDPAScaleFusion::SDPAScaleFusion() { + using namespace ov::pass::pattern; + using namespace ov::gen_pattern; + + auto q = makePattern(ov::Rank(4)); + auto k = makePattern(ov::Rank(4)); + auto v = makePattern(ov::Rank(4)); + auto mask = makePattern(); + auto sdpa_scale = makeConst({}); + auto scale_q = makePattern("[]") | makePattern("[1]"); + auto scale_k = makePattern("[]") | makePattern("[1]"); + + auto scaled_q = optional({q, scale_q}); + auto scaled_k = optional({k, scale_k}); + auto sdpa_mask_scale = + makePattern({scaled_q, scaled_k, v, mask, sdpa_scale}, + {{"causal", false}}); + auto sdpa_mask = + makePattern({scaled_q, scaled_k, v, mask}, {{"causal", false}}); + auto sdpa_simple = + makePattern({scaled_q, scaled_k, v}, {{"causal", false}}); + auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale; + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + if (transformation_callback(m.get_match_root())) { + return false; + } + + auto sdpa = m.get_match_root(); + + const bool has_q_scale = pattern_map.count(scaled_q); + const bool has_k_scale = pattern_map.count(scaled_k); + + // Nothing to do + if (!has_q_scale && !has_k_scale) + return false; + + auto prev_scale_value = 1.0f; + auto scale_q_value = 1.0f; + auto scale_k_value = 1.0f; + auto scale_et = sdpa->get_output_element_type(0); + + Output q_input = sdpa->get_input_source_output(0); + Output k_input = sdpa->get_input_source_output(1); + + std::shared_ptr scale_q_node = nullptr; + std::shared_ptr scale_k_node = nullptr; + + if (pattern_map.find(sdpa_scale) != pattern_map.end()) { + auto prev_scale_node = + ov::as_type_ptr(pattern_map.at(sdpa_scale).get_node_shared_ptr()); + prev_scale_value = prev_scale_node->cast_vector()[0]; + scale_et = prev_scale_node->get_output_element_type(0); + } else { + auto head_size = q_input.get_partial_shape()[3]; + if (head_size.is_dynamic()) + return false; + + prev_scale_value = 1.0f / std::sqrt(static_cast(head_size.get_length())); + } + + // Extract scalar scale values for Q and K if those are constant and set new inputs for SDPA + if (has_q_scale) { + scale_q_node = pattern_map.at(scale_q).get_node_shared_ptr(); + if (ov::is_type(scale_q_node)) { + scale_q_value = ov::as_type_ptr(scale_q_node)->cast_vector()[0]; + q_input = pattern_map.at(q); + } + } + if (has_k_scale) { + scale_k_node = pattern_map.at(scale_k).get_node_shared_ptr(); + if (ov::is_type(scale_k_node)) { + scale_k_value = ov::as_type_ptr(scale_k_node)->cast_vector()[0]; + k_input = pattern_map.at(k); + } + } + + Output new_scale_node; + auto new_scale_val = prev_scale_value * scale_q_value * scale_k_value; + + // If new scale is 1 and we have non-constant scale node for either Q or K, then we can make it a scale of SDPA + if (new_scale_val == 1.0f) { + if (has_q_scale && !ov::is_type(scale_q_node)) { + new_scale_node = pattern_map.at(scale_q); + q_input = pattern_map.at(q); + } else if (has_k_scale && !ov::is_type(scale_k_node)) { + new_scale_node = pattern_map.at(scale_k); + k_input = pattern_map.at(k); + } else { + new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector{new_scale_val}); + } + } else { + new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector{new_scale_val}); + } + + OutputVector new_inputs = {q_input, k_input, pattern_map.at(v)}; + if (pattern_map.find(mask) != pattern_map.end()) { + new_inputs.push_back(pattern_map.at(mask)); + } else { + new_inputs.push_back( + ov::op::v0::Constant::create(new_scale_node.get_element_type(), ov::Shape{}, std::vector{0.0f})); + } + + new_inputs.push_back(new_scale_node); + + auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs); + new_sdpa->set_friendly_name(sdpa->get_friendly_name()); + ov::copy_runtime_info(sdpa, new_sdpa); + ov::replace_node(sdpa, new_sdpa); + + return true; + }; + + auto m = std::make_shared(sdpa, "SDPAScaleFusion"); + this->register_matcher(m, callback); +} + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp new file mode 100644 index 00000000000000..52c10ba5967bd8 --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp @@ -0,0 +1,234 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/transpose.hpp" + +using namespace testing; +using namespace ov::pass; +using namespace ov; + +TEST_F(TransformationTestsF, SDPAFusionTest1) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, true); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{1.0f}); + const auto mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto sdpa = std::make_shared(query, + key, + value, + mask_const, + scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest2) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, true); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto mask_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{0.0f}); + const auto sdpa = std::make_shared(query, + key, + value, + mask_const, + scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest3) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto casual = false; + { + const auto key_t = + std::make_shared(key, + op::v0::Constant::create(element::i64, Shape{4}, {0, 1, 3, 2})); + const auto qk = std::make_shared(query, key_t, false, false); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto mask_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{0.0f}); + const auto sdpa = std::make_shared(query, + key, + value, + mask_const, + scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest4) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, 32, -1}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + { + const auto qk = std::make_shared(query, key, false, false); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + model_ref = model->clone(); + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest5) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + const PartialShape attention_mask_shape{1, 32, -1, -1}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto mask = std::make_shared(element::f16, attention_mask_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, true); + const auto mask_add = std::make_shared(qk, mask); + const auto softmax = std::make_shared(mask_add, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value, mask}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto sdpa = + std::make_shared(query, key, value, mask, scale_const, casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value, mask}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest6) { + const PartialShape query_shape{1, 32, 10, 32}; + const PartialShape key_shape{1, 32, 10, 32}; + const PartialShape value_shape{1, 32, 10, 32}; + const PartialShape attention_mask_shape{1, 1, 10, 10}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto mask = std::make_shared(element::f16, attention_mask_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, true); + const auto mask_add = std::make_shared(qk, mask); + const auto softmax = std::make_shared(mask_add, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value, mask}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto sdpa = + std::make_shared(query, key, value, mask, scale_const, casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value, mask}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest7) { + const PartialShape query_shape{1, 8, -1, 32}; + const PartialShape key_shape{-1, 1, 8, 32}; + const PartialShape value_shape{1, 8, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + { + const auto key_t = + std::make_shared(key, + op::v0::Constant::create(element::i64, Shape{4}, {1, 2, 3, 0})); + const auto qk = std::make_shared(query, key_t, false, false); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } +} diff --git a/src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp new file mode 100644 index 00000000000000..f922f030a9c43b --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp @@ -0,0 +1,228 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" + +using namespace testing; +using namespace ov::pass; +using namespace ov; + +TEST_F(TransformationTestsF, SDPAScaleFusionTest1) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{8.0f}); + const auto v_scaled = std::make_shared(value, scale_const); + const auto casual = false; + { + const auto q_scaled = std::make_shared(query, scale_const); + const auto k_scaled = std::make_shared(key, scale_const); + const auto sdpa = + std::make_shared(q_scaled, k_scaled, v_scaled, casual); + + model = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto new_mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto new_scale_const = + ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{64.0f / std::sqrt(32.0f)}); + const auto sdpa = std::make_shared(query, + key, + v_scaled, + new_mask_const, + new_scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAScaleFusionTest2) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto sdpa_mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto sdpa_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{2.0f}); + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{8.0f}); + const auto v_scaled = std::make_shared(value, scale_const); + const auto casual = false; + { + const auto q_scaled = std::make_shared(query, scale_const); + const auto k_scaled = std::make_shared(key, scale_const); + const auto sdpa = std::make_shared(q_scaled, + k_scaled, + v_scaled, + sdpa_mask_const, + sdpa_scale_const, + casual); + + model = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto new_scale_const = + ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{128.0f}); + const auto sdpa = std::make_shared(query, + key, + v_scaled, + sdpa_mask_const, + new_scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAScaleFusionTest3) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto sdpa_mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto sdpa_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{2.0f}); + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{8.0f}); + const auto v_scaled = std::make_shared(value, scale_const); + const auto casual = false; + { + const auto q_scaled = std::make_shared(query, scale_const); + const auto sdpa = std::make_shared(q_scaled, + key, + v_scaled, + sdpa_mask_const, + sdpa_scale_const, + casual); + + model = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto new_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{16.0f}); + const auto sdpa = std::make_shared(query, + key, + v_scaled, + sdpa_mask_const, + new_scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAScaleFusionTest4) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto sdpa_mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto sdpa_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{2.0f}); + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{8.0f}); + const auto scale_dyn = std::make_shared(element::f32, ov::Shape{}); + const auto v_scaled = std::make_shared(value, scale_const); + const auto casual = false; + const auto q_scaled = std::make_shared(query, scale_dyn); + { + const auto k_scaled = std::make_shared(key, scale_const); + const auto sdpa = std::make_shared(q_scaled, + k_scaled, + v_scaled, + sdpa_mask_const, + sdpa_scale_const, + casual); + + model = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value, scale_dyn}); + manager.register_pass(); + } + + { + const auto new_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{16.0f}); + const auto sdpa = std::make_shared(q_scaled, + key, + v_scaled, + sdpa_mask_const, + new_scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value, scale_dyn}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAScaleFusionTest5) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto sdpa_mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto sdpa_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{1.0f}); + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{1.0f}); + const auto scale_dyn = std::make_shared(element::f32, ov::Shape{}); + const auto v_scaled = std::make_shared(value, scale_const); + const auto casual = false; + { + const auto q_scaled = std::make_shared(query, scale_dyn); + const auto k_scaled = std::make_shared(key, scale_const); + const auto sdpa = std::make_shared(q_scaled, + k_scaled, + v_scaled, + sdpa_mask_const, + sdpa_scale_const, + casual); + + model = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value, scale_dyn}); + manager.register_pass(); + } + + { + const auto sdpa = std::make_shared(query, + key, + v_scaled, + sdpa_mask_const, + scale_dyn, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value, scale_dyn}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index a63377312ecb95..fb9e0925bc89e2 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -37,6 +37,7 @@ #include "transformations/common_optimizations/nop_elimination.hpp" #include "transformations/common_optimizations/reshape_prelu.hpp" #include "transformations/common_optimizations/rms_fusion.hpp" +#include "transformations/common_optimizations/sdpa_fusion.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" #include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp" #include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp" @@ -695,6 +696,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction); CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertScatterNDUpdate15ToScatterNDUpdate3); CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertSliceScatter); + CPU_DISABLE_PASS_COMMON(manager, ov::pass::SDPAFusion); CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition); CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition); diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 53ab9aa188b7aa..7c7c09adcd182f 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -92,6 +92,7 @@ #include "transformations/common_optimizations/lstm_cell_fusion.hpp" #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp" #include "transformations/common_optimizations/mvn_fusion.hpp" +#include "transformations/common_optimizations/sdpa_scale_fusion.hpp" #include "transformations/common_optimizations/softmax_fusion.hpp" #include "transformations/common_optimizations/glu_fusion.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" @@ -941,6 +942,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { if (!disable_horizontal_fc_fusion) manager.register_pass(); + manager.register_pass(); manager.register_pass(); auto pass_config = manager.get_pass_config(); manager.register_pass();