diff --git a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp index e0a4ba285de9cd..097b1156803ca3 100644 --- a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp +++ b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp @@ -147,8 +147,11 @@ bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear const auto& loop_manager = linear_ir.get_loop_manager(); if (!ov::snippets::utils::is_full_dim_value(k_block)) { - const std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true, 0), - LoopPort(brgemm_expr->get_input_port(1), true, 1)}; + // const std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true, 0), + // LoopPort(brgemm_expr->get_input_port(1), true, 1)}; + // const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; + const std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true), + LoopPort(brgemm_expr->get_input_port(1), true)}; const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; mark_k_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, k_block); } diff --git a/src/common/snippets/src/lowered/pass/split_loops.cpp b/src/common/snippets/src/lowered/pass/split_loops.cpp index 4295327225d713..47b3fd9e3e59ef 100644 --- a/src/common/snippets/src/lowered/pass/split_loops.cpp +++ b/src/common/snippets/src/lowered/pass/split_loops.cpp @@ -153,7 +153,7 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, // in case if loops are not split but FuseLoops is registered in pass manager after SplitLoops if (loop_was_split) FuseLoops().run(linear_ir, begin, end); - // + // fuse loop cover case of port update of this split // for (auto expr_it = begin; expr_it != end; expr_it++) { // const auto expr = *expr_it; // const auto& node = expr->get_node(); @@ -165,9 +165,6 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, // std::cout << "loop_dim:" << loop_info->get_dim_idx() << std::endl; // std::cout << "wa:" << loop_info->get_work_amount() << std::endl; // std::cout << "inc:" << loop_info->get_increment() << std::endl; - // if (loop_info->get_dim_idx() == 18446744073709551615) { - // loop_info->set_dim_idx(0); - // } // auto in_ports = loop_info->get_input_ports(); // for (auto in_port : in_ports) { // auto in_port_parent = in_port.expr_port->get_connected_ports().begin()->get_expr(); @@ -176,8 +173,13 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, // auto out_ports = loop_info->get_output_ports(); // for (auto out_port : out_ports) { - // auto out_port_child = out_port.expr_port->get_connected_ports().begin()->get_expr(); - // std::cout << "out_port_child to:" << out_port_child->get_node()->get_friendly_name() << std::endl; + // auto consumers = out_port.expr_port->get_connected_ports(); + // std::cout << "out_port_child is one of:" << std::endl; + // for (auto consumer : consumers) { + // auto out_port_child = consumer.get_expr(); + // std::cout << "out_port_child to:" << out_port_child->get_node()->get_friendly_name() << std::endl; + // } + // std::cout << "out_port_child is one of end."<< std::endl; // } // } // } @@ -199,18 +201,18 @@ size_t SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t ou auto outer_loop_bounds_second = loop_bounds.second; auto in_ports = inner_loop_info->get_input_ports(); auto out_ports = inner_loop_info->get_output_ports(); + // extend outer block loop of inner most dimension to child that have no inner most loop. - // for example, Hsum should perform on every on M_blk*K_blk*M_in_block loops. + // for example, Hmax should perform on every on M_blk*K_blk*M_in_block loops. if (inner_loop_info->get_dim_idx() == 0) { + // extend child for (auto expr_it = outer_loop_bounds_first; expr_it != outer_loop_bounds_second; expr_it++) { auto expr = *expr_it; std::cout << "expr:" << expr->get_node()->get_friendly_name() << std::endl; - const auto& expr_loop = expr->get_loop_ids(); + const auto& expr_loop = expr->get_loop_ids(); // expr_loop is [2,3,1] child is [2,3] if (expr->get_output_count() != 1) break; - auto expr_out_port = expr->get_output_port(0); bool check_next = true; - bool extend_global = false; while (check_next) { if (expr->get_output_count() != 1) break; @@ -233,7 +235,7 @@ size_t SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t ou if (child_expr_loop.size() < 1) continue; const auto& child_last_loop_dim = loop_manager->get_loop_info(child_expr_loop.back())->get_dim_idx(); - if ((expr_loop.size() - child_expr_loop.size() != 1) && child_last_loop_dim == 0) { + if ((expr_loop.size() - child_expr_loop.size() != 1) || child_last_loop_dim == 0) { continue; } // check expr and child have common outer loop id @@ -270,81 +272,83 @@ size_t SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t ou linear_ir.move(child_expr_it, outer_loop_bounds_second); // move new expr in this loop before loop_bounds_end } std::cout << "extend child:" << child_expr->get_node()->get_friendly_name() << std::endl; + auto expr_out_port = expr->get_output_port(0); + inner_loop_info->replace_with_new_ports(expr_out_port, child_expr->get_output_ports()); + expr = child_expr; extend = true; - extend_global = true; break; // if one child is extend, stop other consumers. continue this branch } if (!extend) { check_next = false; // no child extend, stop extend deeper } } - if (extend_global) { - // if expr port is loop port, replace to new child output port. - auto loop_port = std::find_if(out_ports.begin(), out_ports.end(), [&](LoopPort loop_ports) { - return *loop_ports.expr_port.get() == expr_out_port; - }); - if (loop_port != out_ports.end()) { - inner_loop_info->replace_with_new_ports(expr_out_port, {(*std::prev(outer_loop_bounds_second))->get_output_port(0)}); - } - } } // extend parent + std::cout << "start extend parent" << std::endl; for (auto expr_it = outer_loop_bounds_first; expr_it != outer_loop_bounds_second; expr_it++) { auto expr = *expr_it; std::cout << "expr_for_parent:" << expr->get_node()->get_friendly_name() << std::endl; - auto in_num = expr->get_input_count(); - if (in_num > 2 || in_num < 1) - continue; - auto expr_in_port = expr->get_input_port(in_num-1); bool check_next = true; const auto& expr_loop = expr->get_loop_ids(); bool extend = false; while (check_next) { - // if (expr->get_input_count() != 1) - // break; - in_num = expr->get_input_count(); - if (in_num > 2 || in_num < 1) - break; - auto parent_expr = expr->get_input_port_connector(in_num-1)->get_source().get_expr(); - std::cout << "parent_expr:" << parent_expr->get_node()->get_friendly_name() << std::endl; - if (parent_expr->get_output_count() != 1) - break; - bool is_inside = false; - for (auto expr_check = outer_loop_bounds_first; expr_check != outer_loop_bounds_second; expr_check++) { - if (*expr_check == parent_expr) { - is_inside = true; - break; + auto in_num = expr->get_input_count(); + bool extend = false; + for (size_t i = 0; i < in_num; i++) { + auto parent_expr = expr->get_input_port_connector(i)->get_source().get_expr(); + bool is_inside = false; + for (auto expr_check = outer_loop_bounds_first; expr_check != outer_loop_bounds_second; expr_check++) { + if (*expr_check == parent_expr) { + is_inside = true; + break; + } } - } - std::cout << "is_inside:" << is_inside << std::endl; - if (is_inside) - break; - const auto& parent_expr_loop = parent_expr->get_loop_ids(); - // have common outer loop id, child w/o inner most loop - for (auto i = 0; i < std::min(expr_loop.size(), parent_expr_loop.size()); i++) { - if (expr_loop[i] != parent_expr_loop[i]) { - break; + std::cout << "is_inside:" << is_inside << std::endl; + if (is_inside) + continue; + + // child_last_loop_dim is not 0(inner most dimension) + const auto& parent_expr_loop = parent_expr->get_loop_ids(); + if (parent_expr_loop.size() < 1) + continue; + const auto& parent_last_loop_dim = loop_manager->get_loop_info(parent_expr_loop.back())->get_dim_idx(); + if ((expr_loop.size() - parent_expr_loop.size() != 1) || parent_last_loop_dim == 0) { + continue; } - } - const auto& last_loop_dim = loop_manager->get_loop_info(parent_expr_loop.back())->get_dim_idx(); - if ((expr_loop.size() - parent_expr_loop.size() == 1) && last_loop_dim != 0) { - outer_loop_bounds_first = linear_ir.find(parent_expr); + + bool common_outer_loop = true; + // have common outer loop id, child w/o inner most loop + for (auto i = 0; i < std::min(expr_loop.size(), parent_expr_loop.size()); i++) { + if (expr_loop[i] != parent_expr_loop[i]) { + common_outer_loop = false; + break; + } + } + if (!common_outer_loop) { + continue; + } + + // not break data dependency + + // can extend + auto parent_expr_it = linear_ir.find(parent_expr); + if (parent_expr_it != std::prev(outer_loop_bounds_first)) { + linear_ir.move(parent_expr_it, outer_loop_bounds_first); // move new expr in this loop before outer_loop_bounds_first + } + outer_loop_bounds_first--; + std::cout << "extend child:" << parent_expr->get_node()->get_friendly_name() << std::endl; + + // update port. If extend, in_port must be a loop port as parent expr is not in this loop. + auto in_port = expr->get_input_port(i); + inner_loop_info->replace_with_new_ports(in_port, parent_expr->get_input_ports()); + + expr = parent_expr; extend = true; - std::cout << "extend........" << std::endl; - } else { - break; + break; // one parent extended, check this branch from parent. } - expr = parent_expr; - } - if (extend) { - // if expr port is loop port, replace to new child output port. - auto loop_port = std::find_if(in_ports.begin(), in_ports.end(), [&](LoopPort loop_ports) { - return *loop_ports.expr_port.get() == expr_in_port; - }); - if (loop_port != out_ports.end()) { - // inner_loop_info->replace_with_new_ports(expr_in_port, {(*outer_loop_bounds_first)->get_output_port(0)}); - inner_loop_info->replace_with_new_ports(expr_in_port, {}); + if (!extend) { + check_next = false; // no parent extend, stop extend upper } } } diff --git a/src/common/snippets/tests/src/lowered/pass/split_loop.cpp b/src/common/snippets/tests/src/lowered/pass/split_loop.cpp new file mode 100644 index 00000000000000..bdd4e62c0345d6 --- /dev/null +++ b/src/common/snippets/tests/src/lowered/pass/split_loop.cpp @@ -0,0 +1,147 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "lir_test_utils.hpp" + +#include "openvino/opsets/opset10.hpp" +#include "snippets/lowered/pass/split_loops.hpp" +#include "snippets/snippets_isa.hpp" + +namespace ov { +namespace test { +namespace snippets { + +using namespace ov::snippets::lowered; +using namespace ov::snippets::lowered::pass; + +class SplitLoopTest : public LoweredPassTestsF { +public: + SplitLoopTest() : LoweredPassTestsF() { + comparator.enable(LIRComparator::LIRCmpValues::LOOP_INDICES); + comparator.enable(LIRComparator::LIRCmpValues::PORT_DESCRIPTORS); + comparator.enable(LIRComparator::LIRCmpValues::PORT_CONNECTORS); + comparator.enable(LIRComparator::LIRCmpValues::LOOP_MANAGER); + } + + void SetUp() override { + pipeline.register_pass(); + } +}; + +TEST_F(SplitLoopTest, SplitLoopTestTwoDepthBlocks) { + size_t vector_size = 16; + const auto input_precision = ov::element::f32; + const ov::Shape input_shape1{512, 64}; + const ov::Shape input_shape2{64, 1024}; + const ov::Shape input_shape3{1024, 16}; + /* + * Param1 Param2 + * \ / + * Brgemm1 VectorBuffer + * | | | + * | Fill Fill + * | \ / + * | Maximum + * | | + * | HorizonMax + * | | + * Substract Param3 + * | / + * Brgemm2 + * | + * Result + */ + { + auto param1 = linear_ir->push_node(input_precision, input_shape1); + auto param2 = linear_ir->push_node(input_precision, input_shape2); + auto param3 = linear_ir->push_node(input_precision, input_shape3); + auto brgemm1 = linear_ir->push_node(param1.second, param2.second); + const auto vector_buffer = linear_ir->push_node(input_precision); + uint32_t fill_value = 0xff7fffff; + const auto initial_fill = linear_ir->push_node(vector_buffer.second, 0, fill_value); + const auto fill = linear_ir->push_node(brgemm1.second, vector_size, fill_value); + auto max = linear_ir->push_node(initial_fill.second, fill.second); + auto h_max = linear_ir->push_node(max.second); + auto sub = linear_ir->push_node(brgemm1.second, h_max.second); + auto brgemm2 = linear_ir->push_node(sub.second, param3.second); + const auto result = linear_ir->push_node(brgemm2.second); + const auto& loop_manager = linear_ir->get_loop_manager(); + // two loops for brgemm1 + loop_manager->mark_loop(brgemm1.first, vector_buffer.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 1)}); + loop_manager->mark_loop(brgemm1.first, vector_buffer.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), + LoopPort((*brgemm1.first)->get_input_port(1), true, 0)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 0)}); + // loops on column + loop_manager->mark_loop(fill.first, h_max.first, 1024, vector_size, 0, + std::vector{LoopPort((*fill.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); + loop_manager->mark_loop(sub.first, brgemm2.first, 1024, vector_size, 0, + std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 0)}); + // loop on row + loop_manager->mark_loop(vector_buffer.first, brgemm2.first, 512, 1, 1, + std::vector{LoopPort((*fill.first)->get_input_port(0), true, 1), + LoopPort((*sub.first)->get_input_port(0), true, 1)}, + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 1)}); + // two loops for brgemm2 + loop_manager->mark_loop(brgemm2.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + loop_manager->mark_loop(brgemm2.first, result.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 0), + LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + } + { + auto param1 = linear_ir_ref->push_node(input_precision, input_shape1); + auto param2 = linear_ir_ref->push_node(input_precision, input_shape2); + auto param3 = linear_ir_ref->push_node(input_precision, input_shape3); + auto brgemm1 = linear_ir_ref->push_node(param1.second, param2.second); + const auto vector_buffer = linear_ir_ref->push_node(input_precision); + uint32_t fill_value = 0xff7fffff; + const auto initial_fill = linear_ir_ref->push_node(vector_buffer.second, 0, fill_value); + const auto fill = linear_ir_ref->push_node(brgemm1.second, vector_size, fill_value); + auto max = linear_ir_ref->push_node(initial_fill.second, fill.second); + auto h_max = linear_ir_ref->push_node(max.second); + auto sub = linear_ir_ref->push_node(brgemm1.second, h_max.second); + auto brgemm2 = linear_ir_ref->push_node(sub.second, param3.second); + const auto result = linear_ir_ref->push_node(brgemm2.second); + const auto& loop_manager = linear_ir_ref->get_loop_manager(); + + // block inner loops for dimension 0 + loop_manager->mark_loop(fill.first, h_max.first, 64, vector_size, 0, + std::vector{LoopPort((*fill.first)->get_input_port(0), true, 0), + LoopPort((*max.first)->get_input_port(0), false, 0)}, // Max(initial_fill, fill) + std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); + loop_manager->mark_loop(sub.first, brgemm2.first, 64, vector_size, 0, + std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0), + LoopPort((*sub.first)->get_input_port(1), false, 0)}, // sub:brgemm1-Hmax + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 0)}); + // block inner loop for dimension 1. + loop_manager->mark_loop(vector_buffer.first, brgemm2.first, 32, 1, 1, + std::vector{LoopPort((*fill.first)->get_input_port(0), true, 1), + LoopPort((*sub.first)->get_input_port(0), true, 1)}, + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 1)}); + // two block loops. All exprs between two brgemm including h_max should be in both block loops. + loop_manager->mark_loop(brgemm1.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + loop_manager->mark_loop(brgemm1.first, result.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), + LoopPort((*brgemm1.first)->get_input_port(1), true, 0), + LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + } +} + +} // namespace snippets +} // namespace test +} // namespace ov \ No newline at end of file