
Fix cpu tests and pattern
Signed-off-by: Vladimir Paramuzov <[email protected]>
vladimir-paramuzov committed Dec 18, 2024
1 parent e52a306 commit c4f73d8
Showing 3 changed files with 30 additions and 2 deletions.
@@ -3,7 +3,6 @@
 //
 
 #include "transformations/common_optimizations/sdpa_fusion.hpp"
-
 #include "openvino/core/rt_info.hpp"
 #include "openvino/core/type.hpp"
 #include "openvino/op/add.hpp"
@@ -15,6 +14,7 @@
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/pass/pattern/op/optional.hpp"
 #include "openvino/pass/pattern/op/pattern.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "transformations/utils/gen_pattern.hpp"
 
 namespace ov {
@@ -29,7 +29,12 @@ SDPAFusion::SDPAFusion() {
     auto v = makePattern(ov::Rank(4));
     auto mask = makePattern();
 
-    auto k_t = makePattern<ov::op::v1::Transpose>({k, {0, 1, 3, 2}});
+    auto k_transpose_order = pattern::wrap_type<ov::op::v0::Constant>([](const Output<Node>& node) {
+        auto axis_order = std::dynamic_pointer_cast<ov::op::v0::Constant>(node.get_node_shared_ptr())->cast_vector<int64_t>();
+        return axis_order == std::vector<int64_t>{0, 1, 3, 2};
+    });
+
+    auto k_t = pattern::wrap_type<ov::op::v1::Transpose>({k, k_transpose_order});
     auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}});
     auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}});
     auto qk = qk_nt | qk_nn;
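Note on the pattern change: the matcher now requires the key Transpose to use exactly the {0, 1, 3, 2} axis order. The k_transpose_order predicate runs on the Constant that wrap_type matched, reads it back with cast_vector<int64_t>, and compares it against {0, 1, 3, 2}; any other permutation makes the whole SDPA pattern fail to match. A minimal standalone sketch of the same comparison (assumes only OpenVINO's public Constant API; main and the names good, bad, and matches are illustrative, not part of the commit):

#include <openvino/op/constant.hpp>

#include <iostream>
#include <memory>
#include <vector>

int main() {
    // Two candidate transpose-order constants: the one the fusion
    // accepts and the one the new test uses to check rejection.
    const auto good = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 1, 3, 2});
    const auto bad = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {1, 2, 3, 0});

    // The same comparison the predicate performs on the matched node.
    const auto matches = [](const std::shared_ptr<ov::op::v0::Constant>& order) {
        return order->cast_vector<int64_t>() == std::vector<int64_t>{0, 1, 3, 2};
    };

    std::cout << matches(good) << " " << matches(bad) << "\n";  // prints: 1 0
    return 0;
}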
@@ -211,3 +211,24 @@ TEST_F(TransformationTestsF, SDPAFusionTest6) {
     comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
     comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
 }
+
+TEST_F(TransformationTestsF, SDPAFusionTest7) {
+    const PartialShape query_shape{1, 8, -1, 32};
+    const PartialShape key_shape{-1, 1, 8, 32};
+    const PartialShape value_shape{1, 8, -1, 32};
+
+    const auto query = std::make_shared<ov::op::v0::Parameter>(element::f16, query_shape);
+    const auto key = std::make_shared<ov::op::v0::Parameter>(element::f16, key_shape);
+    const auto value = std::make_shared<ov::op::v0::Parameter>(element::f16, value_shape);
+    {
+        const auto key_t =
+            std::make_shared<ov::op::v1::Transpose>(key,
+                                                    op::v0::Constant::create(element::i64, Shape{4}, {1, 2, 3, 0}));
+        const auto qk = std::make_shared<ov::op::v0::MatMul>(query, key_t, false, false);
+        const auto softmax = std::make_shared<ov::op::v8::Softmax>(qk, -1);
+        const auto qkv = std::make_shared<ov::op::v0::MatMul>(softmax, value, false, false);
+
+        model = std::make_shared<ov::Model>(NodeVector{qkv}, ParameterVector{query, key, value});
+        manager.register_pass<ov::pass::SDPAFusion>();
+    }
+}
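What SDPAFusionTest7 checks: the key transpose uses order {1, 2, 3, 0}, which permutes the {-1, 1, 8, 32} key into {1, 8, 32, -1}. That shape is still MatMul-compatible with the {1, 8, -1, 32} query, so without the predicate the pattern could fuse a graph that is not a valid SDPA key transpose. No model_ref is built here, which in the TransformationTestsF fixture effectively makes the reference a clone of the input model, i.e. the pass is expected to leave the graph untouched. A small plain-C++ sketch of the shape arithmetic, with -1 standing for a dynamic dimension (nothing below is from the commit):

#include <array>
#include <cstdint>
#include <iostream>

int main() {
    const std::array<int64_t, 4> key{-1, 1, 8, 32};       // key_shape from the test
    const std::array<std::size_t, 4> order{1, 2, 3, 0};   // transpose order under test
    std::array<int64_t, 4> transposed{};
    for (std::size_t i = 0; i < 4; ++i)
        transposed[i] = key[order[i]];                    // output dim i takes input dim order[i]
    for (const auto d : transposed)
        std::cout << d << ' ';                            // prints: 1 8 32 -1
    std::cout << '\n';
    return 0;
}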
@@ -37,6 +37,7 @@
 #include "transformations/common_optimizations/nop_elimination.hpp"
 #include "transformations/common_optimizations/reshape_prelu.hpp"
 #include "transformations/common_optimizations/rms_fusion.hpp"
+#include "transformations/common_optimizations/sdpa_fusion.hpp"
 #include "transformations/common_optimizations/transpose_sinking.hpp"
 #include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"
 #include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp"
@@ -700,6 +701,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
     CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
     CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertScatterNDUpdate15ToScatterNDUpdate3);
     CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertSliceScatter);
+    CPU_DISABLE_PASS_COMMON(manager, ov::pass::SDPAFusion);
     CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);
 
     CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
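Why the pass is disabled here: with SDPAFusion registered in the common pipeline, attention subgraphs would be rewritten before the CPU plugin's own handling sees them, which is presumably what broke the CPU tests this commit's title refers to; disabling it keeps the plugin's existing behavior and test expectations intact. A rough sketch of what the disable amounts to, assuming OpenVINO's pass-config mechanism (the free function disable_sdpa_fusion is illustrative, not the actual CPU_DISABLE_PASS_COMMON macro):

#include <openvino/pass/manager.hpp>

#include "transformations/common_optimizations/sdpa_fusion.hpp"

void disable_sdpa_fusion(ov::pass::Manager& manager) {
    // The pass stays registered in the shared pipeline, but its
    // pass-config entry is switched off, so run_passes() skips it
    // and the CPU plugin keeps seeing the original SDPA subgraph.
    manager.get_pass_config()->disable<ov::pass::SDPAFusion>();
}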
