Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Fuse SDPA and Concat as early as possible #28189

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ namespace gen_pattern {

#ifdef CPU_DEBUG_CAPS

# ifdef __GNUC__
# define CURRENT_LINE_NO __builtin_LINE()
# define CURRENT_FILE __builtin_FILE()
# else
# define CURRENT_LINE_NO -1
# define CURRENT_FILE ""
# endif

template <typename... Args>
static inline void _verbose_log(Args&&... args) {
std::stringstream ss;
Expand All @@ -58,6 +66,10 @@ static bool matcher_verbose_enabled() {
if (matcher_verbose_enabled()) \
_verbose_log(__VA_ARGS__)
#else

# define CURRENT_LINE_NO -1
# define CURRENT_FILE ""

// Stub used when CPU_DEBUG_CAPS is off: verbose matcher logging is
// permanently disabled in this configuration.
static bool matcher_verbose_enabled() {
    constexpr bool enabled = false;
    return enabled;
}
Expand Down Expand Up @@ -181,6 +193,8 @@ class Symbol {
double literal_const_value;
std::shared_ptr<Entity> lhs;
std::shared_ptr<Entity> rhs;
const char* filename = "";
int line_no = -1;
// _,+,-,*,/
// l : literal const
// n : named symbol
Expand Down Expand Up @@ -220,10 +234,12 @@ class Symbol {
entity->op = 'n';
entity->name = name;
}
Symbol(const int value) {
// Implicit constructor for a literal integer symbol (entity op 'l').
// line_no/file default to the caller's source location (captured via
// __builtin_LINE()/__builtin_FILE() on GCC/Clang; -1/"" elsewhere) so
// that get_name() can report where this literal was created.
Symbol(const int value, int line_no = CURRENT_LINE_NO, const char* file = CURRENT_FILE) {
entity = std::make_shared<Entity>();
entity->op = 'l';  // 'l' marks a literal-const entity
entity->literal_const_value = value;
// Remember the creation site for debug logging.
entity->line_no = line_no;
entity->filename = file;
}
Symbol(char op, const Symbol& lhs, const Symbol& rhs) {
entity = std::make_shared<Entity>();
Expand All @@ -246,8 +262,12 @@ class Symbol {
void* get_id() const {
return entity.get();
}
const char* get_name() const {
return entity->name;
// Returns a human-readable name for this symbol, used in matcher logs.
// Named/independent symbols — and symbols created without source-location
// info (line_no == -1, e.g. non-GCC builds) — keep their original name;
// literal symbols are reported as "filename:lineno" pointing at the code
// line that created them.
std::string get_name() const {
    if (entity->line_no == -1 || is_independent_var())
        return entity->name;
    // Strip the directory part of the path. Scan once instead of calling
    // strrchr twice, and accept both '/' and '\\' so Windows-style paths
    // are shortened too (the original handled only '/').
    const char* base = entity->filename;
    for (const char* p = entity->filename; *p; ++p) {
        if (*p == '/' || *p == '\\')
            base = p + 1;
    }
    return std::string(base) + ":" + std::to_string(entity->line_no);
}
bool operator<(const Symbol& rhs) const {
return get_id() < rhs.get_id();
Expand Down Expand Up @@ -739,7 +759,9 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
explicit GenericPattern(const DiscreteTypeInfo& type_info,
const OutputVector& args,
const detail::AttrMap& attrs,
const char* vt)
const char* vt,
const int line_no = -1,
const char* file = "")
: ov::pass::pattern::op::Pattern(args),
m_type_info(type_info),
m_attrs(attrs),
Expand All @@ -758,6 +780,12 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
sep = ",";
}
ss << ")";
if (line_no != -1) {
// add the code line no to the log:
// O P752<opset1::Multiply>(P736,P745)@fuse_rotary_positional_embeddings.cpp:551 vs ...
auto filename = strrchr(file, '/') ? strrchr(file, '/') + 1 : file;
ss << "@" << filename << ":" << line_no;
}
m_signature = ss.str();
set_friendly_name(std::string("P") + std::to_string(id));
}
Expand All @@ -776,7 +804,13 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
// strictly requires pattern & graph value to come from output port with same index,
// this is absolute necessary when pattern contains split node connections.
if (pattern_value.get_index() != graph_value.get_index()) {
_VERBOSE_LOG(level, "X output index mismatch: ", pattern_value.get_index(), "!=", graph_value.get_index());
_VERBOSE_LOG(level,
"X output index mismatch:(",
m_signature,
"): ",
pattern_value.get_index(),
"!=",
graph_value.get_index());
return false;
}

Expand Down Expand Up @@ -1018,15 +1052,18 @@ template <class T>
std::shared_ptr<Node> makePattern(const std::vector<detail::PatternNode>& inputs,
detail::AttrMap attrmap = {},
const char* vt = nullptr,
const char* friendly_name = nullptr) {
const char* friendly_name = nullptr,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
OutputVector args;
for (auto& in : inputs)
args.push_back(in.get_output());

// pattern nodes are better for pattern matching because
// - it can be generic/incomplete, so normal OP node is not working properly
// - it has predicate to correctly decide which branch to take (in Or pattern)
auto pattern_node = std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt);
auto pattern_node =
std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt, line_no, file);

if (friendly_name)
pattern_node->set_friendly_name(friendly_name);
Expand Down Expand Up @@ -1120,7 +1157,9 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
detail::PatternNode start,
detail::PatternNode stop,
detail::PatternNode step,
size_t axis) {
size_t axis,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
std::vector<int64_t> begin_mask(axis + 1, 1);
std::vector<int64_t> end_mask(axis + 1, 1);
std::vector<int64_t> new_axis_mask;
Expand All @@ -1135,12 +1174,27 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
{"end_mask", end_mask},
{"new_axis_mask", new_axis_mask},
{"shrink_axis_mask", shrink_axis_mask},
{"ellipsis_mask", ellipsis_mask}});
{"ellipsis_mask", ellipsis_mask}},
nullptr,
nullptr,
line_no,
file);
return opt2;
}

inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) {
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}});
inline std::shared_ptr<Node> GenSlice(detail::PatternNode data,
Symbol start,
Symbol stop,
Symbol step,
size_t axis,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}},
{},
nullptr,
nullptr,
line_no,
file);

std::vector<Symbol> vbegin(axis + 1, Symbol(0));
std::vector<Symbol> vend(axis + 1, Symbol(0));
Expand Down Expand Up @@ -1168,7 +1222,11 @@ inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Sy
{"end_mask", end_mask},
{"new_axis_mask", new_axis_mask},
{"shrink_axis_mask", shrink_axis_mask},
{"ellipsis_mask", ellipsis_mask}});
{"ellipsis_mask", ellipsis_mask}},
nullptr,
nullptr,
line_no,
file);
return opt1 | opt2;
}

Expand Down Expand Up @@ -1329,7 +1387,9 @@ class PatternValidator {
auto id = sym.get_id();
if (symbol_value_map.count(id)) {
if (symbol_value_map[id] != value) {
_VERBOSE_LOG(" in-consistency between multiple references of same symbol : ",
_VERBOSE_LOG(" in-consistency between multiple references of same symbol(",
sym.get_name(),
"): ",
symbol_value_map[id],
" != ",
value);
Expand All @@ -1345,7 +1405,12 @@ class PatternValidator {
if (sym.is_literal_const()) {
auto literal = sym.eval(symbol_value_map);
if (literal != value) {
_VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value);
_VERBOSE_LOG(" mismatch between literal symbol & value(",
sym.get_name(),
"): ",
literal,
" != ",
value);
return false;
}
// no need to put literal into value map to eval them.
Expand Down Expand Up @@ -1373,7 +1438,9 @@ class PatternValidator {
}
}
if (!is_match) {
_VERBOSE_LOG(" mismatch between derived & value : ",
_VERBOSE_LOG(" mismatch between derived & value(",
sym.get_name(),
"): ",
std::setprecision(std::numeric_limits<float>::max_digits10),
derived,
" != ",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <openvino/opsets/opset13.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/gen_pattern.hpp>
Expand All @@ -20,7 +21,12 @@
#include "itt.hpp"
#include "openvino/opsets/opset1.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/common_optimizations/simplify_shape_of_sub_graph.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"
#include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp"
#include "transformations/defs.hpp"
#include "transformations/op_conversions/convert_broadcast3.hpp"
#include "transformations/transpose_sinking/ts_shape_of.hpp"
using namespace ov::gen_pattern;

namespace ov {
Expand Down Expand Up @@ -56,8 +62,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
std::shared_ptr<Node> reshape_k, reshape_v, unsqueeze_k, unsqueeze_v;
std::shared_ptr<Node> computed_bcst_k, computed_bcst_v, multiply_k, multiply_v;
std::shared_ptr<Node> mq_reshape_k, mq_reshape_v;
std::shared_ptr<Node> computed_bcst3_k, computed_bcst3_v;
auto multi_query_bcst = [](const std::shared_ptr<Node>& kv) {
auto reshape_kv = wrap_type<opset6::Reshape>({kv, any_input()});
auto reshape_kv = makePattern<opset6::Reshape>({kv, any_input()});
auto unsqueeze_kv = makePattern<opset1::Unsqueeze>({kv, any_input()});

auto check_one = [](Output<Node> output) -> bool {
Expand All @@ -73,13 +80,17 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
makePattern<opset1::Broadcast>({wrap_type<opset1::Constant>(check_one), any_input(), any_input()},
{{"mode", "numpy"}});

auto multiply_kv = wrap_type<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
auto result = wrap_type<opset6::Reshape>({multiply_kv, any_input()});
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv);
auto multiply_kv = makePattern<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
auto computed_bcst3 = makePattern<opset3::Broadcast>({unsqueeze_kv, any_input()}, {{"mode", "bidirectional"}});

auto result = makePattern<opset6::Reshape>({multiply_kv | computed_bcst3, any_input()});
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv, computed_bcst3);
};

std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k) = multi_query_bcst(concat_k);
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v) = multi_query_bcst(concat_v);
std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k) =
multi_query_bcst(concat_k);
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v, computed_bcst3_v) =
multi_query_bcst(concat_v);
auto present_k = concat_k | mq_reshape_k;
auto present_v = concat_v | mq_reshape_v;

Expand Down Expand Up @@ -178,15 +189,19 @@ StatefulSDPAFusion::StatefulSDPAFusion() {

opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node)) {
return false;
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
}
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id()) {
return false;
}

if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node))
if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node)) {
return false;
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
}
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) {
return false;
}

auto is_optional_one_child = [&pattern_map](const std::vector<std::shared_ptr<Node>>& nodes) {
for (auto&& node : nodes) {
Expand All @@ -212,7 +227,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
computed_bcst_v,
multiply_v,
mq_reshape_k,
mq_reshape_v})) {
mq_reshape_v,
computed_bcst3_k,
computed_bcst3_v})) {
return false;
}

Expand Down Expand Up @@ -284,5 +301,18 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
this->register_matcher(m, callback);
}

// Runs the SDPA-related fusion sub-pipeline on the model via a nested pass
// manager. The registration order below is the execution order; presumably
// the shape-of/transpose-sinking passes must run first so that the
// StatefulSDPAFusion pattern can match — TODO confirm before reordering.
bool SDPASubgraphFusion::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(SDPASubgraphFusion);
ov::pass::Manager manager("SDPASubgraphFusion");

CPU_REGISTER_PASS_COMMON(manager, ov::pass::SimplifyGatherShapeOf);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward);
// Fuse SDPA + KV-cache Concat into the internal stateful SDPA op.
CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);
// x64-only follow-up: fold surrounding Transpose/Reshape into the SDPA op.
CPU_REGISTER_PASS_X64(manager, SDPAFuseTransposeReshape);

manager.run_passes(f);
// NOTE(review): always reports "model unchanged" even when the nested
// passes modified it; looks intentional — verify callers don't rely on
// this flag for revalidation.
return false;
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,12 @@ class StatefulSDPAFusion : public ov::pass::MatcherPass {
StatefulSDPAFusion();
};

// Model pass that bundles the SDPA-related fusions (shape-of
// simplification, transpose sinking, StatefulSDPAFusion and, on x64,
// SDPA transpose/reshape fusion) into one sub-pipeline so they can be
// registered as a single unit in the transformation pipeline.
class SDPASubgraphFusion : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("SDPASubgraphFusion", "0");

// Runs the nested pass manager on the model; see the .cpp for pass order.
bool run_on_model(const std::shared_ptr<ov::Model>& f) override;
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
ov::pass::KeepConstAndDecompression);

CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion);
CPU_REGISTER_PASS_COMMON(manager, SDPASubgraphFusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
CPU_REGISTER_PASS_X64(manager, ov::pass::KeepConstsPrecision, decompression_precisions, false, true);
CPU_SET_CALLBACK_X64(
Expand Down Expand Up @@ -653,16 +654,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertNMS9ToNMSIEInternal);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMulticlassNmsToMulticlassNmsIE);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMatrixNmsToMatrixNmsIE);
CPU_SET_CALLBACK_COMMON(
manager,
[this](const_node_ptr& node) -> bool {
std::string errorMsg;
// Current SDPA impl is optimized only for LLM models, so we decompose it for others to avoid perf
// regression. Matching the pattern is a little complicated, so we just check if there is any state nodes.
return node::ScaledDotProductAttention::isSupportedOperation(node, errorMsg) &&
model->get_variables().size() > 0;
},
ov::pass::ScaledDotProductAttentionDecomposition);

// List of enabled/disabled transformations

Expand Down Expand Up @@ -944,9 +935,6 @@ void Transformations::PostLpt() {
}
#endif // OPENVINO_ARCH_X86_64

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::SDPAFuseTransposeReshape);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion, false);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm);
CPU_SET_CALLBACK_X64(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
auto unsqueezeK = std::make_shared<ov::op::v0::Unsqueeze>(concatK, unsquezeAxis);
auto unsqueezeV = std::make_shared<ov::op::v0::Unsqueeze>(concatV, unsquezeAxis);

auto targetShape = ov::op::v0::Constant::create(qkvType, {1, 1, 1, 4, 1}, {1});
auto broadcastK = std::make_shared<ov::op::v1::Multiply>(unsqueezeK, targetShape);
auto broadcastV = std::make_shared<ov::op::v1::Multiply>(unsqueezeV, targetShape);
auto targetShape = ov::op::v0::Constant::create(element::i32, {5}, {1, 1, 1, 4, 1});
auto broadcastK = std::make_shared<ov::op::v3::Broadcast>(unsqueezeK, targetShape, op::BroadcastType::BIDIRECTIONAL);
auto broadcastV = std::make_shared<ov::op::v3::Broadcast>(unsqueezeV, targetShape, op::BroadcastType::BIDIRECTIONAL);

auto target4D = ov::op::v0::Constant::create(ov::element::i32, {4}, {0, 0, 8, 64});

Expand Down
Loading
Loading