diff --git a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp index e7a4adaefca0b2..380db25beb381b 100644 --- a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp +++ b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp @@ -70,15 +70,15 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass { /** * @brief Contains splitM approaches allowing to get the batch ideally divisible by optimal_parallelism_work_amount */ - static std::pair compute_ideal_cases_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + static std::pair split_ideally(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); /** - * @brief Aggressively splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration. + * @brief Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration. */ - static std::pair compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + static std::pair split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); /** - * @brief Conservatively splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval + * @brief Splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval */ - static std::pair compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); + static std::pair split_conservatively_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount); void reshape_subgraph(const std::shared_ptr& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim); diff --git a/src/common/snippets/src/pass/split_dimension_m.cpp b/src/common/snippets/src/pass/split_dimension_m.cpp index 8d2b9e9778165d..5bead40c0dcb54 100644 --- a/src/common/snippets/src/pass/split_dimension_m.cpp +++ b/src/common/snippets/src/pass/split_dimension_m.cpp @@ -31,7 +31,7 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr& return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic(); } -std::pair SplitDimensionM::compute_ideal_cases_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { +std::pair SplitDimensionM::split_ideally(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { // Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel work amount // In this case, each thread will execute the Snippets kernel once const size_t lower_bound = optimal_parallelism_work_amount / batch_dim; @@ -50,7 +50,7 @@ std::pair SplitDimensionM::compute_ideal_cases_heuristic(size_t return std::make_pair(1, m_dim); } -std::pair SplitDimensionM::compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { +std::pair SplitDimensionM::split_conservatively_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { std::pair splited = { 1, m_dim }; const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim); for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) { @@ -61,7 +61,7 @@ std::pair SplitDimensionM::compute_conservative_heuristic(size_t return splited; } -std::pair SplitDimensionM::compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { +std::pair SplitDimensionM::split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { constexpr size_t min_kernel_m = 32; std::pair best_result = {1, m_dim}; for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) { @@ -145,19 +145,26 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_w if (batch_dim % optimal_parallelism_work_amount == 0) return false; - std::tie(batch_m_dim, new_m_dim) = compute_ideal_cases_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount); - if (batch_m_dim != 1) + auto split_is_done = [&batch_m_dim]() { + return batch_m_dim != 1; + }; + + std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount); + if (split_is_done()) return true; // If M dim is big enough, aggressive heuristic is used for kernel_m minimization. // For smaller M dim, conservative heuristic is used to preserve old behavour. const bool big_m_dim = m_dim >= 4000; if (big_m_dim) { - std::tie(batch_m_dim, new_m_dim) = compute_aggressive_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount); - } else if (batch_dim < optimal_parallelism_work_amount) { - std::tie(batch_m_dim, new_m_dim) = compute_conservative_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount); + std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); + if (split_is_done()) + return true; + } + if (batch_dim < optimal_parallelism_work_amount) { + std::tie(batch_m_dim, new_m_dim) = split_conservatively_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); } - return batch_m_dim != 1; + return split_is_done(); } void SplitDimensionM::reshape_subgraph(const std::shared_ptr& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) {