From 7d183e3cdbd8c50eeb4ca3abb072e9343714ba70 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Mon, 23 Dec 2024 11:17:23 +0100 Subject: [PATCH] [Snippets] SplitDimensionM: heuristic update Tests adjustment Final changes --- .../snippets/docs/mha_optimization_guide.md | 2 +- .../snippets/pass/split_dimension_m.hpp | 15 +++- .../snippets/src/pass/split_dimension_m.cpp | 88 ++++++++++++------- .../tests/src/pass/mha_tokenization.cpp | 6 +- .../snippets/tests/src/utils/split_dim_m.cpp | 5 ++ .../transformation_pipeline.cpp | 1 + 6 files changed, 82 insertions(+), 35 deletions(-) diff --git a/src/common/snippets/docs/mha_optimization_guide.md b/src/common/snippets/docs/mha_optimization_guide.md index 28245017833a4a..1ea3a4c24c3524 100644 --- a/src/common/snippets/docs/mha_optimization_guide.md +++ b/src/common/snippets/docs/mha_optimization_guide.md @@ -65,7 +65,7 @@ The supported by decomposition Transpose orders are defined by `TokenizeMHASnipp [SplitDimensionM](../src/pass/split_dimension_m.cpp) splits M dimension of MHA in 2 parts (`batch_m` and `new_m`) by inserting Reshape on A input of the first Matmul and output of the second Matmul (the rest Subgraph's inputs are reshaped by Unsqueeze-like reshapes in order not to break subgraph semantic). This optimization increases parallel work amount by `batch_m` times thus enabling a more efficient parallel execution in some cases. -The splitting is performed based on heuristic algorithm which can be found in `SplitDimensionM::get_splited_dimensions` method. +The splitting is performed based on heuristic algorithm which can be found in `SplitDimensionM::split` method. 
Let's consider an example of the transformation: diff --git a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp index e9a9a46d3847ff..b93f09bf62803e 100644 --- a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp +++ b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp @@ -67,11 +67,24 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass { private: static std::shared_ptr get_matmul(const std::shared_ptr& subgraph); - static std::pair get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + /** + * @brief Contains splitM approaches allowing to get the batch ideally divisible by optimal_parallelism_work_amount + */ + static std::pair split_ideally(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + /** + * @brief Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration. 
+ */ + static std::pair split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + /** + * @brief Splits m_dim to get the batch in [optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval + */ + static std::pair split_fallback_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); void reshape_subgraph(const std::shared_ptr& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim); size_t m_concurrency; + + static const size_t min_kernel_m; }; } // namespace pass } // namespace snippets diff --git a/src/common/snippets/src/pass/split_dimension_m.cpp b/src/common/snippets/src/pass/split_dimension_m.cpp index ae95a371483163..b6b8cdd70f0bc8 100644 --- a/src/common/snippets/src/pass/split_dimension_m.cpp +++ b/src/common/snippets/src/pass/split_dimension_m.cpp @@ -4,8 +4,8 @@ #include "snippets/pass/split_dimension_m.hpp" -#include "snippets/utils/utils.hpp" #include "snippets/itt.hpp" +#include "snippets/utils/utils.hpp" namespace { size_t get_dim_M(const ov::Shape& shape) { @@ -26,50 +26,69 @@ bool is_prime_number(size_t value) { namespace ov { namespace snippets { namespace pass { + +const size_t SplitDimensionM::min_kernel_m = 32; + bool SplitDimensionM::is_supported_matmul(const std::shared_ptr& node) { const auto matmul = ov::as_type_ptr(node); return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic(); } -std::pair SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { - std::pair splited = { 1, m_dim }; - // Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel work amount // In this case, each thread will execute the Snippets kernel once const size_t lower_bound = optimal_parallelism_work_amount / 
batch_dim; - if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) { - splited.first = lower_bound; - splited.second = m_dim / lower_bound; - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); - return splited; - } + if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) + return std::make_pair(lower_bound, m_dim / lower_bound); // Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough // In this case, each thread will execute the Snippets kernel 'batch_dim' times if (m_dim % optimal_parallelism_work_amount == 0) { const auto new_m_dim = m_dim / optimal_parallelism_work_amount; - const size_t min_kernel_m = 64; - if (new_m_dim >= min_kernel_m) { - splited.first = optimal_parallelism_work_amount; - splited.second = new_m_dim; - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); - return splited; - } + if (new_m_dim >= min_kernel_m) + return std::make_pair(optimal_parallelism_work_amount, new_m_dim); } + return std::make_pair(1, m_dim); +} + +std::pair SplitDimensionM::split_fallback_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { + std::pair splited = { 1, m_dim }; const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim); for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) { size_t divisor_1 = m_dim / divisor_0; - if (divisor_1 * divisor_0 == m_dim) { - splited.first = divisor_0; - splited.second = divisor_1; - break; - } + if (divisor_1 * divisor_0 == m_dim) + return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? 
std::make_pair(divisor_0, divisor_1) : splited; } - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); return splited; } +std::pair SplitDimensionM::split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { + // This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with a limitation that 'm_kernel >= min_kernel_m'. + // In other words, it tries to find 'm_kernel' not less than 'min_kernel_m' and at the same time as close as possible to this value. + std::pair best_result = {1, m_dim}; + for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) { + if (m_dim % divisor != 0) + continue; + // If divisor is not less than 'min_kernel_m', divisor becomes 'm_kernel', + // guaranteeing the most optimal implementation from 'm_kernel' minimization perspective. + if (divisor >= min_kernel_m) + return std::make_pair(m_dim / divisor, divisor); + + // If divisor is less than 'min_kernel_m', divisor becomes m_batch. + // However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimized, as one of the next divisors can be more optimal. 
+ // So in this case the best result is remembered + const size_t m_kernel = m_dim / divisor; + if (m_kernel >= min_kernel_m) { + best_result.first = divisor; + best_result.second = m_kernel; + } + } + if (best_result.first * batch_dim >= optimal_parallelism_work_amount) + return best_result; + return std::make_pair(1, m_dim); +} + bool SplitDimensionM::can_be_optimized(const std::shared_ptr& node, size_t concurrency) { if (!is_supported_matmul(node)) return false; @@ -131,16 +150,25 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_w if (is_prime_number(m_dim)) return false; - auto is_optimized = [&](size_t batch_dim) { - return batch_dim >= optimal_parallelism_work_amount; - }; - // We skip optimization if the current batch is optimal for concurrency - if (is_optimized(batch_dim)) + if (batch_dim % optimal_parallelism_work_amount == 0) return false; - std::tie(batch_m_dim, new_m_dim) = get_splited_dimensions(batch_dim, m_dim, optimal_parallelism_work_amount); - return is_optimized(batch_dim * batch_m_dim); + auto split_is_done = [&batch_m_dim]() { + return batch_m_dim != 1; + }; + + std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount); + if (split_is_done()) + return true; + + std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); + if (split_is_done()) + return true; + // If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior + if (batch_dim < optimal_parallelism_work_amount) + std::tie(batch_m_dim, new_m_dim) = split_fallback_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); + return split_is_done(); } void SplitDimensionM::reshape_subgraph(const std::shared_ptr& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) { diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp 
b/src/common/snippets/tests/src/pass/mha_tokenization.cpp index dfd269bba49597..d725c36e5c35a5 100644 --- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp @@ -171,7 +171,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) { TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) { const auto& f = MHASplitMFunction(std::vector{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}}, + std::vector{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}}, true); model = f.getOriginal(); model_ref = f.getReference(); @@ -182,7 +182,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) { TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) { const auto& f = MHASplitMFunction(std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 6, 64, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, + std::vector{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, false); model = f.getOriginal(); model_ref = f.getReference(); @@ -193,7 +193,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) { TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) { const auto& f = MHASplitMFunction(std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 6, 64, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, + std::vector{{1, 12, 32, 16, 
64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, true); model = f.getOriginal(); model_ref = f.getReference(); diff --git a/src/common/snippets/tests/src/utils/split_dim_m.cpp b/src/common/snippets/tests/src/utils/split_dim_m.cpp index 9e801fceae02e9..df7c277d775cb4 100644 --- a/src/common/snippets/tests/src/utils/split_dim_m.cpp +++ b/src/common/snippets/tests/src/utils/split_dim_m.cpp @@ -59,6 +59,11 @@ const std::vector split_dimension_cases = { {InputData{25, 50, 40}, ReferenceData{true, 2, 25}}, {InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}}, {InputData{5, 16384, 32}, ReferenceData{true, 32, 512}}, + {InputData{48, 4097, 32}, ReferenceData{true, 17, 241}}, + {InputData{48, 6600, 32}, ReferenceData{true, 200, 33}}, + {InputData{12, 128, 16}, ReferenceData{true, 4, 32}}, + {InputData{16, 384, 60}, ReferenceData{true, 12, 32}}, + {InputData{16, 384, 24}, ReferenceData{true, 12, 32}}, }; INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM, diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 0acce355e8262f..7b787f2afd0296 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -1110,6 +1110,7 @@ void Transformations::MainSnippets(void) { return false; const auto parallel_work_amount = std::accumulate(shape.rbegin() + 2, shape.rend(), ov::Dimension(1), std::multiplies()); + // Ticket 160154: enable tokenization for MHA with insufficient parallel work amount const auto is_unsupported_parallel_work_amount = static_cast(parallel_work_amount.get_length()) < tokenization_config.get_concurrency() && !ov::snippets::pass::SplitDimensionM::can_be_optimized(n, tokenization_config.get_concurrency());