[Snippets] Disabled non-inplace ops tokenization on 2nd inputs of MatMuls in MHA, supported transposed_b=true

[Snippets] Disable tokenization of the second input of MM

[Snippets][Tests] Updated tokenization unit tests

[Snippets][CPU][Tests] Updated cpu func tests

[Snippets] Updated the explanation in the tokenization pass

[Snippets][CPU] Removed check from tokenization callback

[Snippets] Updated comment in tokenization pass
a-sidorova committed Jan 8, 2025
1 parent 4bfdc7f commit d85fafd
Showing 11 changed files with 131 additions and 168 deletions.
47 changes: 6 additions & 41 deletions src/common/snippets/src/pass/mha_tokenization.cpp
@@ -344,45 +344,6 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
* Transpose3
*/

- // First input branch of MatMul0 should be executed before second input branch of MatMul0,
- // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose0
- // Note: If MatMul0 has transposed_b, we should tokenize only scalars ops from 1st branch
- // to move extracted Transpose from MatMul input to body Parameter
- auto parent = matmul0->get_input_node_shared_ptr(1);
- // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order (or without this Transpose1)
- // only if these ops have scalar shapes on other inputs.
- // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false).
- // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
- const auto is_transposed_b_0 = matmul0->get_transpose_b();
- bool has_matmul0_has_ops_on_input = false;
- while (is_supported_intermediate_op(parent)) {
- // All supported ops have only one output port
- if (parent->get_output_target_inputs(0).size() != 1)
- break;
-
- // Only if MatMul0 has transposed_b, we have to tokenize scalar ops
- // to move explicit Transpose from MatMul0 input_1 to Parameter of Subgraph body
- if (is_transposed_b_0 && !ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent)) {
- break;
- }
-
- // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation)
- // we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body.
- if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(parent)) {
- hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
- }
-
- potential_body_params_count += get_potential_body_params(parent);
- ordered_ops.insert(ordered_ops.begin(), parent);
- // [107731] To go always through 0-th port - is it safe?
- parent = parent->get_input_node_shared_ptr(0);
- has_matmul0_has_ops_on_input = true;
- }
- // If there are ops on second input of MatMul0 and only one unique Buffer between MatMuls - there must be one more unique Buffer
- if (has_matmul0_has_ops_on_input && uniqie_buffer_reg_group_count < 2) {
- uniqie_buffer_reg_group_count++;
- }
-
auto tokenize_transpose = [&](const std::shared_ptr<ov::opset1::Transpose>& transpose,
bool is_input_transposed, std::vector<int32_t> order,
const ov::NodeVector::const_iterator& pos) {
@@ -404,11 +365,15 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
}
};

- const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
+ // [160177]: Due to performance problems, operations on the 2nd input of MatMuls are not tokenized into the Subgraph
+ // when they would have to be executed explicitly (i.e. when a Buffer would have to be inserted between the Brgemm
+ // and this op sequence). The details are described in ticket 160177.
+ // Please restore the tokenization of these ops when parallel loops are implemented.
const auto transpose0 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(0));
+ const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(1));
const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
- tokenize_transpose(transpose1, is_transposed_b_0, get_decomposed_transpose_order(pattern_rank), ordered_ops.begin());
tokenize_transpose(transpose0, matmul0->get_transpose_a(), get_fusion_transpose_order(pattern_rank), ordered_ops.begin());
+ tokenize_transpose(transpose1, matmul0->get_transpose_b(), get_fusion_transpose_order(pattern_rank), ordered_ops.begin());
tokenize_transpose(transpose2, matmul1->get_transpose_b(), get_fusion_transpose_order(pattern_rank), ordered_ops.end());
ordered_ops.push_back(matmul1);

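Taken together, the hunk replaces the decomposed-order handling of Transpose1 with a uniform fusion-order call driven by the MatMul attributes. A minimal sketch of the MHA-like pattern this targets, written against the public opset API (the shapes, the softmax axis, and the helper name are illustrative assumptions, not code from this repository):

```cpp
#include <memory>

#include <openvino/core/model.hpp>
#include <openvino/opsets/opset1.hpp>

using namespace ov;

// MHA-like pattern: K is consumed through transpose_b=true, so no explicit
// Transpose (and no Buffer-separated op sequence) is left on MatMul0's second input.
std::shared_ptr<Model> make_mha_like(const PartialShape& qkv_shape) {
    auto q = std::make_shared<opset1::Parameter>(element::f32, qkv_shape);
    auto k = std::make_shared<opset1::Parameter>(element::f32, qkv_shape);
    auto v = std::make_shared<opset1::Parameter>(element::f32, qkv_shape);
    auto mm0 = std::make_shared<opset1::MatMul>(q, k, /*transpose_a=*/false, /*transpose_b=*/true);
    auto softmax = std::make_shared<opset1::Softmax>(mm0, /*axis=*/3);  // last axis of a 4D score tensor
    auto mm1 = std::make_shared<opset1::MatMul>(softmax, v);
    return std::make_shared<Model>(OutputVector{mm1}, ParameterVector{q, k, v});
}
```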
8 changes: 4 additions & 4 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
@@ -160,7 +160,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Dynamic_Transpose_fusion) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
- std::vector<Shape>{{2, 64, 12, 64}, {128, 12, 1, 64}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
+ std::vector<Shape>{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
false);
model = f.getOriginal();
model_ref = f.getReference();
@@ -171,7 +171,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
- std::vector<Shape>{{2, 64, 12, 64}, {128, 12, 1, 64}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
+ std::vector<Shape>{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
@@ -182,7 +182,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
- std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
+ std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
false);
model = f.getOriginal();
model_ref = f.getReference();
@@ -193,7 +193,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
- std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
+ std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
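All four updated reference Shapes on the second input follow the same pattern: the old layout with its M-sized dimension moved to the end, plausibly keeping K in transposed form now that transposed_b is supported. A quick self-contained check (the permutation orders are inferred from the values above, not taken from MHASplitMFunction):

```cpp
#include <array>
#include <cstddef>

// Apply a permutation to a shape: out[i] = shape[order[i]].
template <std::size_t R>
std::array<std::size_t, R> permute(const std::array<std::size_t, R>& shape,
                                   const std::array<std::size_t, R>& order) {
    std::array<std::size_t, R> out{};
    for (std::size_t i = 0; i < R; ++i)
        out[i] = shape[order[i]];
    return out;
}

// permute({128, 12, 1, 64}, {1, 2, 3, 0}) == {12, 1, 64, 128}, the new 3D-case reference shape;
// permute({1, 384, 16, 1, 64}, {0, 2, 3, 4, 1}) == {1, 16, 1, 64, 384}, the new 4D-case one.
```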
@@ -1081,10 +1081,7 @@ void Transformations::MainSnippets(void) {
// Only FP32 dynamic MHA is supported
if (matmul->is_dynamic())
return false;
- // [114487] brgemm kernel in oneDNN requires brgemm_copy_b kernel if MatMul node has transposed_b=True
- // The current solution with ExtractExplicitMatMulTranspose pass is slower for non-f32 cases than using of
- // brgemm_copy_b kernel
- if (matmul->get_transpose_a() || matmul->get_transpose_b())
+ if (matmul->get_transpose_a())
return false;
// [150842] The execution of Brgemm INT8/BF16/FP16 on AMX platforms depends on the value of "K % VNNIFactor".
// For more details, please take a look at the ticket 150842
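Net effect on the CPU callback: only transpose_a still rejects an MHA candidate here, since transposed weights can go through the Brgemm repacking path. A condensed sketch of the predicate as it reads after this change (a simplified restatement of the hunk above, not the verbatim plugin code):

```cpp
#include <memory>

#include <openvino/op/matmul.hpp>

// Simplified: dynamic shapes and transpose_a still bail out of non-f32 MHA
// tokenization, but transpose_b == true no longer blocks it.
bool is_tokenizable_matmul(const std::shared_ptr<ov::op::v0::MatMul>& matmul) {
    if (matmul->is_dynamic())
        return false;  // only FP32 dynamic MHA is supported; it is handled elsewhere
    if (matmul->get_transpose_a())
        return false;  // a transposed first input remains unsupported
    return true;       // the K % VNNIFactor checks for AMX platforms would follow here
}
```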
@@ -296,7 +296,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHA,
ElementType::f32}),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
- ::testing::Values(ExpectedNodes{{"Subgraph", 1}}),
+ ::testing::Values(ExpectedNodes{{"Subgraph", 2}}), // MHA + Decomposed Transpose on input
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);

@@ -309,7 +309,7 @@ INSTANTIATE_TEST_SUITE_P(
std::vector<ElementType>{ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16}),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
- ::testing::Values(ExpectedNodes{{"Subgraph", 1},
+ ::testing::Values(ExpectedNodes{{"Subgraph", 2}, // MHA + Decomposed Transpose on input
{"Transpose", 1}}), // Plugin disables tokenization of Transpose on output
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);
@@ -323,7 +323,7 @@ INSTANTIATE_TEST_SUITE_P(
std::vector<ElementType>{ElementType::f16, ElementType::f16, ElementType::f16, ElementType::f16}),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
- ::testing::Values(ExpectedNodes{{"Subgraph", 1},
+ ::testing::Values(ExpectedNodes{{"Subgraph", 2}, // MHA + Decomposed Transpose on input
{"Transpose", 1}}), // Plugin disables tokenization of Transpose on output
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);
@@ -694,7 +694,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern0,
::testing::Values(0),
::testing::Values(ExpectedNodes{
{"Subgraph", 5}, // FQs on inputs x 3 + MHA + Deq Mul
{"Transpose", 1}}), // Transpose between MHA and Deq Mul
{"Transpose", 2}}), // Decomposed Transpose on input + Transpose between MHA and Deq Mul
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHAQuantTest::getTestCaseName);

@@ -706,7 +706,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern1,
::testing::Values(1),
::testing::Values(ExpectedNodes{
{"Subgraph", 4}, // FQ on input x 2 + MHA + Deq Mul
{"Transpose", 1}}), // Transpose between MHA and Deq Mul
{"Transpose", 2}}), // Decomposed Transpose on input + Transpose between MHA and Deq Mul
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHAQuantTest::getTestCaseName);

@@ -717,7 +717,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern2,
::testing::ValuesIn(matMulIn0PrecisionsQuant),
::testing::Values(2),
::testing::Values(ExpectedNodes{{"Subgraph", 3}, // FQ on inputs x 2 + MHA
{"Transpose", 0}}), // Transpose is fused
{"Transpose", 1}}), // Decomposed Transpose on input
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHAQuantTest::getTestCaseName);
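For orientation, ExpectedNodes pairs an executed-node type with the number of occurrences the test expects; the counts above grow by one wherever the decomposed input Transpose now executes as a standalone node. A minimal sketch of this kind of check (illustrative only; the alias and helper here are assumptions, not the actual test utilities):

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

using ExpectedNodes = std::map<std::string, std::size_t>;  // node type -> expected count

// Count node types observed in an executed graph and compare with expectations.
bool check_expected_nodes(const std::vector<std::string>& exec_types, const ExpectedNodes& expected) {
    std::map<std::string, std::size_t> actual;
    for (const auto& type : exec_types)
        ++actual[type];
    for (const auto& [type, count] : expected)
        if (actual[type] != count)
            return false;  // e.g. {"Subgraph", 2} now also counts the decomposed Transpose subgraph
    return true;
}
```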
