diff --git a/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp index f759be2b5de293..b92bfe999114b9 100644 --- a/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp @@ -3,7 +3,6 @@ // #include "transformations/common_optimizations/sdpa_fusion.hpp" - #include "openvino/core/rt_info.hpp" #include "openvino/core/type.hpp" #include "openvino/op/add.hpp" @@ -15,6 +14,7 @@ #include "openvino/op/unsqueeze.hpp" #include "openvino/pass/pattern/op/optional.hpp" #include "openvino/pass/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/gen_pattern.hpp" namespace ov { @@ -29,7 +29,12 @@ SDPAFusion::SDPAFusion() { auto v = makePattern(ov::Rank(4)); auto mask = makePattern(); - auto k_t = makePattern({k, {0, 1, 3, 2}}); + auto k_transpose_order = pattern::wrap_type([](const Output& node) { + auto axis_order = std::dynamic_pointer_cast(node.get_node_shared_ptr())->cast_vector(); + return axis_order == std::vector{0, 1, 3, 2}; + }); + + auto k_t = pattern::wrap_type({k, k_transpose_order}); auto qk_nn = makePattern({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}}); auto qk_nt = makePattern({q, k}, {{"transpose_a", false}, {"transpose_b", true}}); auto qk = qk_nt | qk_nn; diff --git a/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp index aef62a783c6d60..52c10ba5967bd8 100644 --- a/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp +++ b/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp @@ -211,3 +211,24 @@ TEST_F(TransformationTestsF, SDPAFusionTest6) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); } + +TEST_F(TransformationTestsF, SDPAFusionTest7) { + const PartialShape query_shape{1, 8, -1, 32}; + const PartialShape key_shape{-1, 1, 8, 32}; + const PartialShape value_shape{1, 8, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + { + const auto key_t = + std::make_shared(key, + op::v0::Constant::create(element::i64, Shape{4}, {1, 2, 3, 0})); + const auto qk = std::make_shared(query, key_t, false, false); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } +} diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 469abbd99eb149..9066be8adaf866 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -37,6 +37,7 @@ #include "transformations/common_optimizations/nop_elimination.hpp" #include "transformations/common_optimizations/reshape_prelu.hpp" #include "transformations/common_optimizations/rms_fusion.hpp" +#include "transformations/common_optimizations/sdpa_fusion.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" #include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp" #include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp" @@ -700,6 +701,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction); CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertScatterNDUpdate15ToScatterNDUpdate3); CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertSliceScatter); + CPU_DISABLE_PASS_COMMON(manager, ov::pass::SDPAFusion); CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition); CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);