diff --git a/src/common/transformations/include/transformations/utils/gen_pattern.hpp b/src/common/transformations/include/transformations/utils/gen_pattern.hpp index 976561b4844a17..215825d2cd13eb 100644 --- a/src/common/transformations/include/transformations/utils/gen_pattern.hpp +++ b/src/common/transformations/include/transformations/utils/gen_pattern.hpp @@ -40,6 +40,14 @@ namespace gen_pattern { #ifdef CPU_DEBUG_CAPS +# ifdef __GNUC__ +# define CURRENT_LINE_NO __builtin_LINE() +# define CURRENT_FILE __builtin_FILE() +# else +# define CURRENT_LINE_NO -1 +# define CURRENT_FILE "" +# endif + template static inline void _verbose_log(Args&&... args) { std::stringstream ss; @@ -58,6 +66,10 @@ static bool matcher_verbose_enabled() { if (matcher_verbose_enabled()) \ _verbose_log(__VA_ARGS__) #else + +# define CURRENT_LINE_NO -1 +# define CURRENT_FILE "" + static bool matcher_verbose_enabled() { return false; } @@ -181,6 +193,8 @@ class Symbol { double literal_const_value; std::shared_ptr lhs; std::shared_ptr rhs; + const char* filename = ""; + int line_no = -1; // _,+,-,*,/ // l : literal const // n : named symbol @@ -220,10 +234,12 @@ class Symbol { entity->op = 'n'; entity->name = name; } - Symbol(const int value) { + Symbol(const int value, int line_no = CURRENT_LINE_NO, const char* file = CURRENT_FILE) { entity = std::make_shared(); entity->op = 'l'; entity->literal_const_value = value; + entity->line_no = line_no; + entity->filename = file; } Symbol(char op, const Symbol& lhs, const Symbol& rhs) { entity = std::make_shared(); @@ -246,8 +262,12 @@ class Symbol { void* get_id() const { return entity.get(); } - const char* get_name() const { - return entity->name; + std::string get_name() const { + if (entity->line_no == -1 || is_independent_var()) + return entity->name; + auto filename = strrchr(entity->filename, '/') ? strrchr(entity->filename, '/') + 1 : entity->filename; + std::string name(filename); // use filename:lineno instead + return name + ":" + std::to_string(entity->line_no); } bool operator<(const Symbol& rhs) const { return get_id() < rhs.get_id(); @@ -739,7 +759,9 @@ class GenericPattern : public ov::pass::pattern::op::Pattern { explicit GenericPattern(const DiscreteTypeInfo& type_info, const OutputVector& args, const detail::AttrMap& attrs, - const char* vt) + const char* vt, + const int line_no = -1, + const char* file = "") : ov::pass::pattern::op::Pattern(args), m_type_info(type_info), m_attrs(attrs), @@ -758,6 +780,12 @@ class GenericPattern : public ov::pass::pattern::op::Pattern { sep = ","; } ss << ")"; + if (line_no != -1) { + // add the code line no to the log: + // O P752(P736,P745)@fuse_rotary_positional_embeddings.cpp:551 vs ... + auto filename = strrchr(file, '/') ? strrchr(file, '/') + 1 : file; + ss << "@" << filename << ":" << line_no; + } m_signature = ss.str(); set_friendly_name(std::string("P") + std::to_string(id)); } @@ -776,7 +804,13 @@ class GenericPattern : public ov::pass::pattern::op::Pattern { // strictly requires pattern & graph value to come from output port with same index, // this is absolute necessary when pattern contains split node connections. if (pattern_value.get_index() != graph_value.get_index()) { - _VERBOSE_LOG(level, "X output index mismatch: ", pattern_value.get_index(), "!=", graph_value.get_index()); + _VERBOSE_LOG(level, + "X output index mismatch:(", + m_signature, + "): ", + pattern_value.get_index(), + "!=", + graph_value.get_index()); return false; } @@ -1018,7 +1052,9 @@ template std::shared_ptr makePattern(const std::vector& inputs, detail::AttrMap attrmap = {}, const char* vt = nullptr, - const char* friendly_name = nullptr) { + const char* friendly_name = nullptr, + int line_no = CURRENT_LINE_NO, + const char* file = CURRENT_FILE) { OutputVector args; for (auto& in : inputs) args.push_back(in.get_output()); @@ -1026,7 +1062,8 @@ std::shared_ptr makePattern(const std::vector& inputs // pattern nodes are better for pattern matching because // - it can be generic/incomplete, so normal OP node is not working properly // - it has predicate to correctly decide which branch to take (in Or pattern) - auto pattern_node = std::make_shared(T::get_type_info_static(), args, attrmap, vt); + auto pattern_node = + std::make_shared(T::get_type_info_static(), args, attrmap, vt, line_no, file); if (friendly_name) pattern_node->set_friendly_name(friendly_name); @@ -1120,7 +1157,9 @@ inline std::shared_ptr GenStridedSlice(detail::PatternNode data, detail::PatternNode start, detail::PatternNode stop, detail::PatternNode step, - size_t axis) { + size_t axis, + int line_no = CURRENT_LINE_NO, + const char* file = CURRENT_FILE) { std::vector begin_mask(axis + 1, 1); std::vector end_mask(axis + 1, 1); std::vector new_axis_mask; @@ -1135,12 +1174,27 @@ inline std::shared_ptr GenStridedSlice(detail::PatternNode data, {"end_mask", end_mask}, {"new_axis_mask", new_axis_mask}, {"shrink_axis_mask", shrink_axis_mask}, - {"ellipsis_mask", ellipsis_mask}}); + {"ellipsis_mask", ellipsis_mask}}, + nullptr, + nullptr, + line_no, + file); return opt2; } -inline std::shared_ptr GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) { - auto opt1 = makePattern({data, {start}, {stop}, {step}, {static_cast(axis)}}); +inline std::shared_ptr GenSlice(detail::PatternNode data, + Symbol start, + Symbol stop, + Symbol step, + size_t axis, + int line_no = CURRENT_LINE_NO, + const char* file = CURRENT_FILE) { + auto opt1 = makePattern({data, {start}, {stop}, {step}, {static_cast(axis)}}, + {}, + nullptr, + nullptr, + line_no, + file); std::vector vbegin(axis + 1, Symbol(0)); std::vector vend(axis + 1, Symbol(0)); @@ -1168,7 +1222,11 @@ inline std::shared_ptr GenSlice(detail::PatternNode data, Symbol start, Sy {"end_mask", end_mask}, {"new_axis_mask", new_axis_mask}, {"shrink_axis_mask", shrink_axis_mask}, - {"ellipsis_mask", ellipsis_mask}}); + {"ellipsis_mask", ellipsis_mask}}, + nullptr, + nullptr, + line_no, + file); return opt1 | opt2; } @@ -1329,7 +1387,9 @@ class PatternValidator { auto id = sym.get_id(); if (symbol_value_map.count(id)) { if (symbol_value_map[id] != value) { - _VERBOSE_LOG(" in-consistency between multiple references of same symbol : ", + _VERBOSE_LOG(" in-consistency between multiple references of same symbol(", + sym.get_name(), + "): ", symbol_value_map[id], " != ", value); @@ -1345,7 +1405,12 @@ class PatternValidator { if (sym.is_literal_const()) { auto literal = sym.eval(symbol_value_map); if (literal != value) { - _VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value); + _VERBOSE_LOG(" mismatch between literal symbol & value(", + sym.get_name(), + "): ", + literal, + " != ", + value); return false; } // no need to put literal into value map to eval them. @@ -1373,7 +1438,9 @@ class PatternValidator { } } if (!is_match) { - _VERBOSE_LOG(" mismatch between derived & value : ", + _VERBOSE_LOG(" mismatch between derived & value(", + sym.get_name(), + "): ", std::setprecision(std::numeric_limits::max_digits10), derived, " != ", diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index 447adb0b2fe23f..fe4a71f44be958 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -20,7 +21,12 @@ #include "itt.hpp" #include "openvino/opsets/opset1.hpp" #include "ov_ops/type_relaxed.hpp" +#include "transformations/common_optimizations/simplify_shape_of_sub_graph.hpp" #include "transformations/cpu_opset/common/op/sdpa.hpp" +#include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp" +#include "transformations/defs.hpp" +#include "transformations/op_conversions/convert_broadcast3.hpp" +#include "transformations/transpose_sinking/ts_shape_of.hpp" using namespace ov::gen_pattern; namespace ov { @@ -56,8 +62,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() { std::shared_ptr reshape_k, reshape_v, unsqueeze_k, unsqueeze_v; std::shared_ptr computed_bcst_k, computed_bcst_v, multiply_k, multiply_v; std::shared_ptr mq_reshape_k, mq_reshape_v; + std::shared_ptr computed_bcst3_k, computed_bcst3_v; auto multi_query_bcst = [](const std::shared_ptr& kv) { - auto reshape_kv = wrap_type({kv, any_input()}); + auto reshape_kv = makePattern({kv, any_input()}); auto unsqueeze_kv = makePattern({kv, any_input()}); auto check_one = [](Output output) -> bool { @@ -73,13 +80,17 @@ StatefulSDPAFusion::StatefulSDPAFusion() { makePattern({wrap_type(check_one), any_input(), any_input()}, {{"mode", "numpy"}}); - auto multiply_kv = wrap_type({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst}); - auto result = wrap_type({multiply_kv, any_input()}); - return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv); + auto multiply_kv = makePattern({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst}); + auto computed_bcst3 = makePattern({unsqueeze_kv, any_input()}, {{"mode", "bidirectional"}}); + + auto result = makePattern({multiply_kv | computed_bcst3, any_input()}); + return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv, computed_bcst3); }; - std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k) = multi_query_bcst(concat_k); - std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v) = multi_query_bcst(concat_v); + std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k) = + multi_query_bcst(concat_k); + std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v, computed_bcst3_v) = + multi_query_bcst(concat_v); auto present_k = concat_k | mq_reshape_k; auto present_v = concat_v | mq_reshape_v; @@ -178,15 +189,19 @@ StatefulSDPAFusion::StatefulSDPAFusion() { opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr; opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr; - if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node)) + if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node)) { return false; - if (past_k_node->get_variable_id() != assign_k_node->get_variable_id()) + } + if (past_k_node->get_variable_id() != assign_k_node->get_variable_id()) { return false; + } - if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node)) + if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node)) { return false; - if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) + } + if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) { return false; + } auto is_optional_one_child = [&pattern_map](const std::vector>& nodes) { for (auto&& node : nodes) { @@ -212,7 +227,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() { computed_bcst_v, multiply_v, mq_reshape_k, - mq_reshape_v})) { + mq_reshape_v, + computed_bcst3_k, + computed_bcst3_v})) { return false; } @@ -284,5 +301,18 @@ StatefulSDPAFusion::StatefulSDPAFusion() { this->register_matcher(m, callback); } +bool SDPASubgraphFusion::run_on_model(const std::shared_ptr& f) { + RUN_ON_FUNCTION_SCOPE(SDPASubgraphFusion); + ov::pass::Manager manager("SDPASubgraphFusion"); + + CPU_REGISTER_PASS_COMMON(manager, ov::pass::SimplifyGatherShapeOf); + CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward); + CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion); + CPU_REGISTER_PASS_X64(manager, SDPAFuseTransposeReshape); + + manager.run_passes(f); + return false; +} + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp index 96028402aa9f92..59494736bb2c2e 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp @@ -14,5 +14,12 @@ class StatefulSDPAFusion : public ov::pass::MatcherPass { StatefulSDPAFusion(); }; +class SDPASubgraphFusion : public ov::pass::ModelPass { +public: + OPENVINO_RTTI("SDPASubgraphFusion", "0"); + + bool run_on_model(const std::shared_ptr& f) override; +}; + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index fb9e0925bc89e2..9fca3f1a62709b 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -430,6 +430,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis ov::pass::KeepConstAndDecompression); CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion); + CPU_REGISTER_PASS_COMMON(manager, SDPASubgraphFusion); CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations); CPU_REGISTER_PASS_X64(manager, ov::pass::KeepConstsPrecision, decompression_precisions, false, true); CPU_SET_CALLBACK_X64( @@ -653,16 +654,6 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertNMS9ToNMSIEInternal); CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMulticlassNmsToMulticlassNmsIE); CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMatrixNmsToMatrixNmsIE); - CPU_SET_CALLBACK_COMMON( - manager, - [this](const_node_ptr& node) -> bool { - std::string errorMsg; - // Current SDPA impl is optimized only for LLM models, so we decompose it for others to avoid perf - // regression. Matching the pattern is a little complicated, so we just check if there is any state nodes. - return node::ScaledDotProductAttention::isSupportedOperation(node, errorMsg) && - model->get_variables().size() > 0; - }, - ov::pass::ScaledDotProductAttentionDecomposition); // List of enabled/disabled transformations @@ -944,9 +935,6 @@ void Transformations::PostLpt() { } #endif // OPENVINO_ARCH_X86_64 - CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward); - CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion); - CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::SDPAFuseTransposeReshape); CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion, false); CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm); CPU_SET_CALLBACK_X64( diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_multiple_query_sdp.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_multiple_query_sdp.cpp index d74ab99fb3d5ab..fe5ba2b7eac5e7 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_multiple_query_sdp.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_multiple_query_sdp.cpp @@ -152,9 +152,9 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface(concatK, unsquezeAxis); auto unsqueezeV = std::make_shared(concatV, unsquezeAxis); - auto targetShape = ov::op::v0::Constant::create(qkvType, {1, 1, 1, 4, 1}, {1}); - auto broadcastK = std::make_shared(unsqueezeK, targetShape); - auto broadcastV = std::make_shared(unsqueezeV, targetShape); + auto targetShape = ov::op::v0::Constant::create(element::i32, {5}, {1, 1, 1, 4, 1}); + auto broadcastK = std::make_shared(unsqueezeK, targetShape, op::BroadcastType::BIDIRECTIONAL); + auto broadcastV = std::make_shared(unsqueezeV, targetShape, op::BroadcastType::BIDIRECTIONAL); auto target4D = ov::op::v0::Constant::create(ov::element::i32, {4}, {0, 0, 8, 64}); diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_transpose_sdp_transpose.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_transpose_sdp_transpose.cpp index f4166544af2bf2..8ba978e32c4b9c 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_transpose_sdp_transpose.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_transpose_sdp_transpose.cpp @@ -71,7 +71,7 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterfaceGetParam(); std::vector& inputShapes = inputShapeAndOrders.first; transposeOrder = inputShapeAndOrders.second; @@ -124,6 +123,10 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface(inputParams[0], preOrder); + std::shared_ptr transposeQ_shapeof; + if (hasShapeOf) { + transposeQ_shapeof = std::make_shared(transposeQ); + } auto concat_axis = transposeOrder[2]; auto beam_idx = std::make_shared(ElementType::i32, ov::PartialShape{-1}); @@ -166,6 +169,7 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface(results, sinks, inputParams, "ConcatTranposeSDP"); @@ -237,6 +241,7 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface transposeOrder; + bool hasShapeOf; }; class ConcatSDPTransposeTest : public ConcatSDPTransposeTestBase { @@ -287,7 +292,10 @@ TEST_P(ConcatSDPTransposeTest, CompareWithRefs) { CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0); CheckNumberOfNodesWithType(compiledModel, "Reorder", 0); CheckNumberOfNodesWithType(compiledModel, "Transpose", 1); - CheckNumberOfNodesWithType(compiledModel, "Gather", 0); + // Transformation TSShapeOfForward will change: + // ?->transpose->shapeof ==> ?-->shapeof->gather + // |->transpose + CheckNumberOfNodesWithType(compiledModel, "Gather", hasShapeOf ? 1 : 0); auto expectedOutputs = run_test(functionRefs); CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0); for (size_t i = 0; i < actualOutputs.size(); i++) {