diff --git a/src/common/snippets/docs/mha_optimization_guide.md b/src/common/snippets/docs/mha_optimization_guide.md index 28245017833a4a..1ea3a4c24c3524 100644 --- a/src/common/snippets/docs/mha_optimization_guide.md +++ b/src/common/snippets/docs/mha_optimization_guide.md @@ -65,7 +65,7 @@ The supported by decomposition Transpose orders are defined by `TokenizeMHASnipp [SplitDimensionM](../src/pass/split_dimension_m.cpp) splits M dimension of MHA in 2 parts (`batch_m` and `new_m`) by inserting Reshape on A input of the first Matmul and output of the second Matmul (the rest Subgraph's inputs are reshaped by Unsqueeze-like reshapes in order not to break subgraph semantic). This optimization increases parallel work amount by `batch_m` times thus enabling a more efficient parallel execution in some cases. -The splitting is performed based on heuristic algorithm which can be found in `SplitDimensionM::get_splited_dimensions` method. +The splitting is performed based on heuristic algorithm which can be found in `SplitDimensionM::split` method. Let's consider an example of the transformation: diff --git a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp index e9a9a46d3847ff..e7a4adaefca0b2 100644 --- a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp +++ b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp @@ -67,7 +67,18 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass { private: static std::shared_ptr get_matmul(const std::shared_ptr& subgraph); - static std::pair get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + /** + * @brief Contains splitM approaches allowing to get the batch ideally divisible by optimal_parallelism_work_amount + */ + static std::pair compute_ideal_cases_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + /** + * @brief Aggressively splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration. + */ + static std::pair compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + /** + * @brief Conservatively splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval + */ + static std::pair compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); void reshape_subgraph(const std::shared_ptr& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim); diff --git a/src/common/snippets/src/pass/split_dimension_m.cpp b/src/common/snippets/src/pass/split_dimension_m.cpp index ae95a371483163..8d2b9e9778165d 100644 --- a/src/common/snippets/src/pass/split_dimension_m.cpp +++ b/src/common/snippets/src/pass/split_dimension_m.cpp @@ -4,8 +4,8 @@ #include "snippets/pass/split_dimension_m.hpp" -#include "snippets/utils/utils.hpp" #include "snippets/itt.hpp" +#include "snippets/utils/utils.hpp" namespace { size_t get_dim_M(const ov::Shape& shape) { @@ -31,45 +31,55 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr& return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic(); } -std::pair SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { - std::pair splited = { 1, m_dim }; - +std::pair SplitDimensionM::compute_ideal_cases_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { // Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel work amount // In this case, each thread will execute the Snippets kernel once const size_t lower_bound = optimal_parallelism_work_amount / batch_dim; - if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) { - splited.first = lower_bound; - splited.second = m_dim / lower_bound; - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); - return splited; - } + if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) + return std::make_pair(lower_bound, m_dim / lower_bound); // Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough // In this case, each thread will execute the Snippets kernel 'batch_dim' times if (m_dim % optimal_parallelism_work_amount == 0) { const auto new_m_dim = m_dim / optimal_parallelism_work_amount; const size_t min_kernel_m = 64; - if (new_m_dim >= min_kernel_m) { - splited.first = optimal_parallelism_work_amount; - splited.second = new_m_dim; - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); - return splited; - } + if (new_m_dim >= min_kernel_m) + return std::make_pair(optimal_parallelism_work_amount, new_m_dim); } + return std::make_pair(1, m_dim); +} + +std::pair SplitDimensionM::compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { + std::pair splited = { 1, m_dim }; const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim); for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) { size_t divisor_1 = m_dim / divisor_0; - if (divisor_1 * divisor_0 == m_dim) { - splited.first = divisor_0; - splited.second = divisor_1; - break; - } + if (divisor_1 * divisor_0 == m_dim) + return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1) : splited; } - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); return splited; } +std::pair SplitDimensionM::compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { + constexpr size_t min_kernel_m = 32; + std::pair best_result = {1, m_dim}; + for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) { + if (m_dim % divisor != 0) + continue; + if (divisor >= min_kernel_m) + return std::make_pair(m_dim / divisor, divisor); + const size_t m_kernel = m_dim / divisor; + if (m_kernel >= min_kernel_m) { + best_result.first = divisor; + best_result.second = m_kernel; + } + } + if (best_result.first * batch_dim >= optimal_parallelism_work_amount) + return best_result; + return std::make_pair(1, m_dim); +} + bool SplitDimensionM::can_be_optimized(const std::shared_ptr& node, size_t concurrency) { if (!is_supported_matmul(node)) return false; @@ -131,16 +141,23 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_w if (is_prime_number(m_dim)) return false; - auto is_optimized = [&](size_t batch_dim) { - return batch_dim >= optimal_parallelism_work_amount; - }; - // We skip optimization if the current batch is optimal for concurrency - if (is_optimized(batch_dim)) + if (batch_dim % optimal_parallelism_work_amount == 0) return false; - std::tie(batch_m_dim, new_m_dim) = get_splited_dimensions(batch_dim, m_dim, optimal_parallelism_work_amount); - return is_optimized(batch_dim * batch_m_dim); + std::tie(batch_m_dim, new_m_dim) = compute_ideal_cases_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount); + if (batch_m_dim != 1) + return true; + + // If M dim is big enough, aggressive heuristic is used for kernel_m minimization. + // For smaller M dim, conservative heuristic is used to preserve old behavour. + const bool big_m_dim = m_dim >= 4000; + if (big_m_dim) { + std::tie(batch_m_dim, new_m_dim) = compute_aggressive_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount); + } else if (batch_dim < optimal_parallelism_work_amount) { + std::tie(batch_m_dim, new_m_dim) = compute_conservative_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount); + } + return batch_m_dim != 1; } void SplitDimensionM::reshape_subgraph(const std::shared_ptr& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) { diff --git a/src/common/snippets/tests/src/utils/split_dim_m.cpp b/src/common/snippets/tests/src/utils/split_dim_m.cpp index 9e801fceae02e9..a310008b2e2e23 100644 --- a/src/common/snippets/tests/src/utils/split_dim_m.cpp +++ b/src/common/snippets/tests/src/utils/split_dim_m.cpp @@ -59,6 +59,8 @@ const std::vector split_dimension_cases = { {InputData{25, 50, 40}, ReferenceData{true, 2, 25}}, {InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}}, {InputData{5, 16384, 32}, ReferenceData{true, 32, 512}}, + {InputData{48, 4097, 32}, ReferenceData{true, 17, 241}}, + {InputData{48, 6600, 32}, ReferenceData{true, 200, 33}}, }; INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,