[Snippets] SplitDimensionM: heuristic update
Tests adjustment

Final changes
v-Golubev authored and a-sidorova committed Jan 8, 2025
1 parent d85fafd commit 7d183e3
Showing 6 changed files with 82 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/docs/mha_optimization_guide.md
@@ -65,7 +65,7 @@ The supported by decomposition Transpose orders are defined by `TokenizeMHASnipp

[SplitDimensionM](../src/pass/split_dimension_m.cpp) splits the M dimension of MHA into 2 parts (`batch_m` and `new_m`) by inserting a Reshape on the A input of the first Matmul and on the output of the second Matmul (the rest of the Subgraph's inputs are reshaped by Unsqueeze-like Reshapes so as not to break the subgraph semantics).
This optimization increases the parallel work amount by a factor of `batch_m`, thus enabling more efficient parallel execution in some cases.
The splitting is performed based on a heuristic algorithm, which can be found in the `SplitDimensionM::get_splited_dimensions` method.
The splitting is performed based on a heuristic algorithm, which can be found in the `SplitDimensionM::split` method.
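As a schematic illustration (with numbers borrowed from the `smoke_Snippets_MHA3D_SplitM_withMul` test updated in this commit, not from the guide's own example): choosing `batch_m = 4` and `new_m = 32` reshapes the A input of the first Matmul from `{128, 12, 64}` to `{4, 32, 12, 64}` and reshapes the output of the second Matmul back to `{128, 12, 64}`, increasing the parallel work amount 4 times.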

Let's consider an example of the transformation:

15 changes: 14 additions & 1 deletion src/common/snippets/include/snippets/pass/split_dimension_m.hpp
@@ -67,11 +67,24 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass {

private:
static std::shared_ptr<ov::op::v0::MatMul> get_matmul(const std::shared_ptr<op::Subgraph>& subgraph);
static std::pair<size_t, size_t> get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Contains splitM approaches that make the resulting batch ideally divisible by optimal_parallelism_work_amount
*/
static std::pair<size_t, size_t> split_ideally(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration.
*/
static std::pair<size_t, size_t> split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Splits m_dim to get the batch into the (optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval
*/
static std::pair<size_t, size_t> split_fallback_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);

void reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim);

size_t m_concurrency;

static const size_t min_kernel_m;
};
} // namespace pass
} // namespace snippets
88 changes: 58 additions & 30 deletions src/common/snippets/src/pass/split_dimension_m.cpp
@@ -4,8 +4,8 @@

#include "snippets/pass/split_dimension_m.hpp"

#include "snippets/utils/utils.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils/utils.hpp"

namespace {
size_t get_dim_M(const ov::Shape& shape) {
@@ -26,50 +26,69 @@ bool is_prime_number(size_t value) {
namespace ov {
namespace snippets {
namespace pass {

const size_t SplitDimensionM::min_kernel_m = 32;

bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>& node) {
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(node);
return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic();
}

std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };

std::pair<size_t, size_t> SplitDimensionM::split_ideally(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
// Ideal case #1: M can be split into parts, one of which complements the batch dimension to the optimal parallel work amount
// In this case, each thread will execute the Snippets kernel once
const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) {
splited.first = lower_bound;
splited.second = m_dim / lower_bound;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0)
return std::make_pair(lower_bound, m_dim / lower_bound);

// Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough
// In this case, each thread will execute the Snippets kernel 'batch_dim' times
if (m_dim % optimal_parallelism_work_amount == 0) {
const auto new_m_dim = m_dim / optimal_parallelism_work_amount;
const size_t min_kernel_m = 64;
if (new_m_dim >= min_kernel_m) {
splited.first = optimal_parallelism_work_amount;
splited.second = new_m_dim;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
if (new_m_dim >= min_kernel_m)
return std::make_pair(optimal_parallelism_work_amount, new_m_dim);
}

return std::make_pair(1, m_dim);
}
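Below is a self-contained sketch of the same two "ideal" branches (an editorial illustration, not part of the commit; the function name `split_ideally_sketch` is invented here), checked against the pre-existing `split_dim_m.cpp` cases `{5, 16384, 40} -> {8, 2048}` and `{5, 16384, 32} -> {32, 512}`:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

std::pair<size_t, size_t> split_ideally_sketch(size_t batch_dim, size_t m_dim, size_t work_amount) {
    const size_t min_kernel_m = 32;  // mirrors SplitDimensionM::min_kernel_m
    // Case #1: batch_dim * (M part) == work_amount => each thread runs the kernel once
    const size_t lower_bound = work_amount / batch_dim;
    if (lower_bound * batch_dim == work_amount && m_dim % lower_bound == 0)
        return {lower_bound, m_dim / lower_bound};
    // Case #2: M is divisible by work_amount and the kernel part stays big enough
    if (m_dim % work_amount == 0 && m_dim / work_amount >= min_kernel_m)
        return {work_amount, m_dim / work_amount};
    return {1, m_dim};  // no ideal split exists
}

int main() {
    assert((split_ideally_sketch(5, 16384, 40) == std::pair<size_t, size_t>{8, 2048}));  // case #1
    assert((split_ideally_sketch(5, 16384, 32) == std::pair<size_t, size_t>{32, 512}));  // case #2
}
```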

std::pair<size_t, size_t> SplitDimensionM::split_fallback_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };
const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim);
for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) {
size_t divisor_1 = m_dim / divisor_0;
if (divisor_1 * divisor_0 == m_dim) {
splited.first = divisor_0;
splited.second = divisor_1;
break;
}
if (divisor_1 * divisor_0 == m_dim)
return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1) : splited;
}
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
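A worked trace of this fallback on the pre-existing test case `{batch_dim = 25, m_dim = 50, concurrency = 40}`: the ideal and kernel-minimizing heuristics fail (no divisor of 50 produces an `m_kernel >= 32`), so the fallback computes `upper_bound = div_up(2 * 40, 25) = 4` and scans `divisor_0 = 3, 2`; since 2 divides 50 and `2 * 25 = 50 >= 40`, the split is `{2, 25}`, matching the reference `{true, 2, 25}`.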

std::pair<size_t, size_t> SplitDimensionM::split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
// This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with the limitation that 'm_kernel >= min_kernel_m'.
// In other words, it tries to find an 'm_kernel' that is not less than 'min_kernel_m' and at the same time as close to it as possible.
std::pair<size_t, size_t> best_result = {1, m_dim};
for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) {
if (m_dim % divisor != 0)
continue;
// If divisor is at least 'min_kernel_m', divisor becomes 'm_kernel',
// guaranteeing the smallest 'm_kernel' that satisfies the 'm_kernel >= min_kernel_m' limitation.
if (divisor >= min_kernel_m)
return std::make_pair(m_dim / divisor, divisor);

// If divisor is less than 'min_kernel_m', divisor becomes 'm_batch'.
// However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimal, as one of the next divisors may yield a smaller one.
// So in this case the best result so far is remembered
const size_t m_kernel = m_dim / divisor;
if (m_kernel >= min_kernel_m) {
best_result.first = divisor;
best_result.second = m_kernel;
}
}
if (best_result.first * batch_dim >= optimal_parallelism_work_amount)
return best_result;
return std::make_pair(1, m_dim);
}
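The same heuristic as a self-contained sketch (again an illustration with an invented name, not part of the commit), validated against the two new `split_dim_m.cpp` cases:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>

std::pair<size_t, size_t> split_minimize_kernel_wa_sketch(size_t batch_dim, size_t m_dim, size_t work_amount) {
    const size_t min_kernel_m = 32;
    std::pair<size_t, size_t> best_result = {1, m_dim};
    for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) {
        if (m_dim % divisor != 0)
            continue;
        if (divisor >= min_kernel_m)              // first divisor >= 32 gives the smallest m_kernel
            return {m_dim / divisor, divisor};
        const size_t m_kernel = m_dim / divisor;  // here divisor acts as m_batch
        if (m_kernel >= min_kernel_m)
            best_result = {divisor, m_kernel};    // remember; a later divisor may shrink m_kernel further
    }
    if (best_result.first * batch_dim >= work_amount)
        return best_result;
    return {1, m_dim};
}

int main() {
    assert((split_minimize_kernel_wa_sketch(48, 6600, 32) == std::pair<size_t, size_t>{200, 33}));
    assert((split_minimize_kernel_wa_sketch(48, 4097, 32) == std::pair<size_t, size_t>{17, 241}));
}
```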

bool SplitDimensionM::can_be_optimized(const std::shared_ptr<const ov::Node>& node, size_t concurrency) {
if (!is_supported_matmul(node))
return false;
@@ -131,16 +150,25 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_w
if (is_prime_number(m_dim))
return false;

auto is_optimized = [&](size_t batch_dim) {
return batch_dim >= optimal_parallelism_work_amount;
};

// We skip optimization if the current batch is optimal for concurrency
if (is_optimized(batch_dim))
if (batch_dim % optimal_parallelism_work_amount == 0)
return false;

std::tie(batch_m_dim, new_m_dim) = get_splited_dimensions(batch_dim, m_dim, optimal_parallelism_work_amount);
return is_optimized(batch_dim * batch_m_dim);
auto split_is_done = [&batch_m_dim]() {
return batch_m_dim != 1;
};

std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount);
if (split_is_done())
return true;

std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
if (split_is_done())
return true;
// If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior
if (batch_dim < optimal_parallelism_work_amount)
std::tie(batch_m_dim, new_m_dim) = split_fallback_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
return split_is_done();
}
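Tracing the full chain on the new test case `{batch_dim = 16, m_dim = 384, concurrency = 60}`: 16 is not a multiple of 60, so the split is attempted; `split_ideally` fails (`60 / 16 = 3`, but `3 * 16 != 60`, and `384 % 60 != 0`); `split_minimize_kernel_wa` walks the divisors of 384 below `sqrt(384)` and remembers `{12, 32}`, the candidate with the smallest `m_kernel` still satisfying `m_kernel >= min_kernel_m`; since `12 * 16 >= 60`, this result is returned and the fallback is never reached.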

void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) {
6 changes: 3 additions & 3 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
@@ -171,7 +171,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
std::vector<Shape>{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
@@ -182,7 +182,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
false);
model = f.getOriginal();
model_ref = f.getReference();
@@ -193,7 +193,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
5 changes: 5 additions & 0 deletions src/common/snippets/tests/src/utils/split_dim_m.cpp
@@ -59,6 +59,11 @@ const std::vector<SplitDimensionMParams> split_dimension_cases = {
{InputData{25, 50, 40}, ReferenceData{true, 2, 25}},
{InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}},
{InputData{5, 16384, 32}, ReferenceData{true, 32, 512}},
{InputData{48, 4097, 32}, ReferenceData{true, 17, 241}},
{InputData{48, 6600, 32}, ReferenceData{true, 200, 33}},
{InputData{12, 128, 16}, ReferenceData{true, 4, 32}},
{InputData{16, 384, 60}, ReferenceData{true, 12, 32}},
{InputData{16, 384, 24}, ReferenceData{true, 12, 32}},
};
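The five new cases pin down the updated heuristics: `{48, 4097, 32}` and `{48, 6600, 32}` exercise `split_minimize_kernel_wa` on awkward M values (4097 = 17 * 241, and 17 is its only divisor below `sqrt(4097)`), while the last three all converge to `m_kernel = 32`, i.e. exactly the new `min_kernel_m` lower bound.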

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -1110,6 +1110,7 @@ void Transformations::MainSnippets(void) {
return false;
const auto parallel_work_amount =
std::accumulate(shape.rbegin() + 2, shape.rend(), ov::Dimension(1), std::multiplies<ov::Dimension>());
// Ticket 160154: enable tokenization for MHA with insufficient parallel work amount
const auto is_unsupported_parallel_work_amount =
static_cast<size_t>(parallel_work_amount.get_length()) < tokenization_config.get_concurrency() &&
!ov::snippets::pass::SplitDimensionM::can_be_optimized(n, tokenization_config.get_concurrency());
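For context: `parallel_work_amount` is the product of all dimensions except the two innermost matrix dimensions, so with this change MHA tokenization is skipped only when that amount is below the configured concurrency and `SplitDimensionM::can_be_optimized` additionally reports that splitting M cannot compensate for it.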
