diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index 6be07fda9f33e6..03ff70c88667f1 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -40,6 +40,7 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op { std::vector layout_a = {}, std::vector layout_b = {}, std::vector layout_c = {}, float beta = 0.f); Brgemm() = default; + Brgemm(bool c_pre_scale); size_t get_offset_a() const { return get_input_offset(0); } size_t get_offset_b() const { return get_input_offset(1); } diff --git a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp index 097b1156803ca3..c361243cdae6ec 100644 --- a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp +++ b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp @@ -146,24 +146,38 @@ bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear brgemm_expr->get_output_port_descriptor(0)->set_subtensor(ov::snippets::VectorDims{m_block, n_block}); const auto& loop_manager = linear_ir.get_loop_manager(); + const auto& brgemm_node = std::dynamic_pointer_cast((*brgemm_it)->get_node()); if (!ov::snippets::utils::is_full_dim_value(k_block)) { // const std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true, 0), // LoopPort(brgemm_expr->get_input_port(1), true, 1)}; // const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; - const std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true), - LoopPort(brgemm_expr->get_input_port(1), true)}; - const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; + std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true, 0), + LoopPort(brgemm_expr->get_input_port(1), true, 0)}; + // if (brgemm_node->with_c_pre_ops) { + // entries.push_back(LoopPort(brgemm_expr->get_input_port(2), false, 0)); + // 
} + const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false, 0)}; mark_k_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, k_block); } if (!ov::snippets::utils::is_full_dim_value(n_block)) { - const std::vector entries{LoopPort(brgemm_expr->get_input_port(0), false), + std::vector entries{LoopPort(brgemm_expr->get_input_port(0), false), LoopPort(brgemm_expr->get_input_port(1), true)}; + // if (brgemm_node->with_c_pre_ops) { + // entries.push_back(LoopPort(brgemm_expr->get_input_port(2), false)); + // } const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; mark_n_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, n_block); } if (!ov::snippets::utils::is_full_dim_value(m_block)) { - const std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true), - LoopPort(brgemm_expr->get_input_port(1), false)}; + std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true), + LoopPort(brgemm_expr->get_input_port(1), false)}; + // if (brgemm_node->with_c_pre_ops) { + // std::cout << "brgemm_node->with_c_pre_ops.................." << std::endl; + // entries.push_back(LoopPort(brgemm_expr->get_input_port(2), true)); + // } + std::cout << "brgemm_node:" << brgemm_node->get_friendly_name() << std::endl; + std::cout << "brgemm_node->with_c_pre_ops not.................." 
<< std::endl; + std::cout << "brgemm_node input:" << brgemm_node->get_input_count() << std::endl; const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; mark_m_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, m_block); } diff --git a/src/common/snippets/src/lowered/pass/split_loops.cpp b/src/common/snippets/src/lowered/pass/split_loops.cpp index 47b3fd9e3e59ef..edf508b894bc24 100644 --- a/src/common/snippets/src/lowered/pass/split_loops.cpp +++ b/src/common/snippets/src/lowered/pass/split_loops.cpp @@ -188,6 +188,39 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, return loop_was_split; } +std::vector get_loop_input_exprs(const std::vector& loop_in_ports) { + std::vector parent_exprs; + std::unordered_set seen_exprs; + for (size_t port_num = 0; port_num < loop_in_ports.size(); ++port_num) { + const auto& parent_expr = loop_in_ports[port_num].expr_port->get_connected_ports().begin()->get_expr(); + if (seen_exprs.count(parent_expr) == 0) { + parent_exprs.push_back(parent_expr); + seen_exprs.insert(parent_expr); + } + } + return parent_exprs; +} + +// std::vector is_parent_extend_applicable(const std::vector& loop_in_ports) { +// parent_expr +// } + +std::vector get_loop_output_exprs(const std::vector& loop_out_ports) { + std::vector child_exprs; + std::unordered_set seen_exprs; + for (size_t port_num = 0; port_num < loop_out_ports.size(); ++port_num) { + const auto& consumers = loop_out_ports[port_num].expr_port->get_connected_ports(); + for (const auto consumer : consumers) { + auto child_expr = consumer.get_expr(); + if (seen_exprs.count(child_expr) == 0) { + child_exprs.push_back(child_expr); + seen_exprs.insert(child_expr); + } + } + } + return child_exprs; +} + size_t SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t outer_increment, size_t loop_position) { const auto& loop_manager = linear_ir.get_loop_manager(); @@ -202,164 +235,172 @@ size_t 
SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t ou auto in_ports = inner_loop_info->get_input_ports(); auto out_ports = inner_loop_info->get_output_ports(); - // extend outer block loop of inner most dimension to child that have no inner most loop. - // for example, Hmax should perform on every on M_blk*K_blk*M_in_block loops. - if (inner_loop_info->get_dim_idx() == 0) { - // extend child - for (auto expr_it = outer_loop_bounds_first; expr_it != outer_loop_bounds_second; expr_it++) { - auto expr = *expr_it; - std::cout << "expr:" << expr->get_node()->get_friendly_name() << std::endl; - const auto& expr_loop = expr->get_loop_ids(); // expr_loop is [2,3,1] child is [2,3] - if (expr->get_output_count() != 1) - break; - bool check_next = true; - while (check_next) { - if (expr->get_output_count() != 1) - break; - const auto& consumers = expr->get_output_port_connector(0)->get_consumers(); - bool extend = false; - for (const auto& consumer : consumers) { - const auto& child_expr = consumer.get_expr(); - // check if child is already in outer block loop - bool is_inside = false; - for (auto expr_check = outer_loop_bounds_first; expr_check != outer_loop_bounds_second; expr_check++) { - if (*expr_check == child_expr) { - is_inside = true; - break; - } - } - if (is_inside) - continue; - // child_last_loop_dim is not 0(inner most dimension) - const auto& child_expr_loop = child_expr->get_loop_ids(); - if (child_expr_loop.size() < 1) - continue; - const auto& child_last_loop_dim = loop_manager->get_loop_info(child_expr_loop.back())->get_dim_idx(); - if ((expr_loop.size() - child_expr_loop.size() != 1) || child_last_loop_dim == 0) { - continue; - } - // check expr and child have common outer loop id - bool have_common_outer_loop = true; - for (auto i = 0; i < std::min(expr_loop.size(), child_expr_loop.size()); i++) { - if (expr_loop[i] != child_expr_loop[i]) { - have_common_outer_loop = false; - break; - } - } - if (!have_common_outer_loop) - continue; - 
- // all inputs of child not break data dependency - bool data_conflict = false; - size_t in_num = child_expr->get_input_count(); - for (size_t i = 0; i < in_num; i++) { - auto parent = child_expr->get_input_port_connector(i)->get_source().get_expr(); - if (parent == expr) - continue; - if (parent->get_exec_num() > (*outer_loop_bounds_second)->get_exec_num()) { - data_conflict = true; - break; - } - } - if (data_conflict) - continue; + // save loop ids before mark + const auto expr_loop = (*outer_loop_bounds_first)->get_loop_ids(); + const auto outer_loop_id = loop_manager->mark_loop(outer_loop_bounds_first, outer_loop_bounds_second, inner_loop_info->get_work_amount(), + outer_increment, inner_loop_info->get_dim_idx(), + in_ports, out_ports, + false, true, loop_position); + const auto& outer_loop_info = loop_manager->get_loop_info(outer_loop_id); + // extend based on loop ports + if (outer_loop_info->get_dim_idx() == 0) { + // extend parent + bool continue_extend_parent = true; + while (continue_extend_parent) { + const auto& input_ports = outer_loop_info->get_input_ports(); + const auto& potential_extended_exprs = get_loop_input_exprs(input_ports); + bool expr_extend = false; + for (const auto& parent_expr : potential_extended_exprs) { + // child_last_loop_dim is not 0(inner most dimension) + std::cout << "parent_expr:" << parent_expr->get_node()->get_friendly_name() << std::endl; + const auto& parent_expr_loop = parent_expr->get_loop_ids(); + // for (auto loop : parent_expr_loop) { + // std::cout << "parent_expr_loop:" << loop << std::endl; + // } + // for (auto loop : expr_loop) { + // std::cout << "expr_loop:" << loop << std::endl; + // } + if (parent_expr_loop.size() < 1) + continue; + const auto& parent_last_loop_dim = loop_manager->get_loop_info(parent_expr_loop.back())->get_dim_idx(); + // max[0,1] h_max[0]. after outer dim split: max[2,3,1] h_max[2,3]. 
after inner dim split: max[2,4,3,5] h_max[2,4,3] + if ((expr_loop.size() - parent_expr_loop.size() != 1) || parent_last_loop_dim == 0) { + continue; + } - // can extend - auto child_expr_it = linear_ir.find(child_expr); - if (child_expr_it == outer_loop_bounds_second) { - outer_loop_bounds_second++; // child is outer_loop_bounds_second, just update loop_bound_end to next - } else { - linear_ir.move(child_expr_it, outer_loop_bounds_second); // move new expr in this loop before loop_bounds_end + bool common_outer_loop = true; + // have common outer loop id, child w/o inner most loop. max[2,3,1] h_max[2,3]. + for (auto i = 0; i < std::min(expr_loop.size(), parent_expr_loop.size()); i++) { + if (expr_loop[i] != parent_expr_loop[i]) { + common_outer_loop = false; + break; } - std::cout << "extend child:" << child_expr->get_node()->get_friendly_name() << std::endl; - auto expr_out_port = expr->get_output_port(0); - inner_loop_info->replace_with_new_ports(expr_out_port, child_expr->get_output_ports()); - - expr = child_expr; - extend = true; - break; // if one child is extend, stop other consumers. 
continue this branch } - if (!extend) { - check_next = false; // no child extend, stop extend deeper + if (!common_outer_loop) { + continue; } - } - } - // extend parent - std::cout << "start extend parent" << std::endl; - for (auto expr_it = outer_loop_bounds_first; expr_it != outer_loop_bounds_second; expr_it++) { - auto expr = *expr_it; - std::cout << "expr_for_parent:" << expr->get_node()->get_friendly_name() << std::endl; - bool check_next = true; - const auto& expr_loop = expr->get_loop_ids(); - bool extend = false; - while (check_next) { - auto in_num = expr->get_input_count(); - bool extend = false; - for (size_t i = 0; i < in_num; i++) { - auto parent_expr = expr->get_input_port_connector(i)->get_source().get_expr(); - bool is_inside = false; - for (auto expr_check = outer_loop_bounds_first; expr_check != outer_loop_bounds_second; expr_check++) { - if (*expr_check == parent_expr) { - is_inside = true; - break; - } - } - std::cout << "is_inside:" << is_inside << std::endl; - if (is_inside) - continue; - - // child_last_loop_dim is not 0(inner most dimension) - const auto& parent_expr_loop = parent_expr->get_loop_ids(); - if (parent_expr_loop.size() < 1) - continue; - const auto& parent_last_loop_dim = loop_manager->get_loop_info(parent_expr_loop.back())->get_dim_idx(); - if ((expr_loop.size() - parent_expr_loop.size() != 1) || parent_last_loop_dim == 0) { - continue; - } + // std::cout << "parent_expr common_outer_loop:" << common_outer_loop << std::endl; + + // todo: check not break data dependency - bool common_outer_loop = true; - // have common outer loop id, child w/o inner most loop - for (auto i = 0; i < std::min(expr_loop.size(), parent_expr_loop.size()); i++) { - if (expr_loop[i] != parent_expr_loop[i]) { - common_outer_loop = false; - break; + // can extend + auto parent_expr_it = linear_ir.find(parent_expr); + if (parent_expr_it != std::prev(outer_loop_bounds_first)) { + linear_ir.move(parent_expr_it, outer_loop_bounds_first); // move new expr in 
this loop before outer_loop_bounds_first + } + outer_loop_bounds_first--; + auto parent_loops = parent_expr_loop; // 16,17 + // std::cout << "loop_position:" << loop_position << std::endl; // 17 + auto insert_pos = std::find_if(parent_loops.begin(), parent_loops.end(), [&](size_t loop){ + return loop == loop_position; + }); + // if (insert_pos != parent_loops.end()) { + // std::cout << "insert_pos:" << (*insert_pos) << std::endl; // 17 + // } + parent_loops.insert(insert_pos, outer_loop_id); // insert 18 before 17 + parent_expr->set_loop_ids(parent_loops); + // for (auto loop : parent_loops) { + // std::cout << "parent_loops:" << loop << std::endl; + // } + // std::cout << "extend parent:" << parent_expr->get_node()->get_friendly_name() << std::endl; + + // update ports + // if parent out port is loop input port, delete the loop input port. add parent input port as new loop input port + const auto& new_in_ports = parent_expr->get_input_ports(); + bool inserted = false; + for (const auto& out_port : parent_expr->get_output_ports()) { + for (const auto& in_port_connect : out_port.get_connected_ports()) { + if (outer_loop_info->is_loop_port(in_port_connect)) { + outer_loop_info->replace_with_new_ports(in_port_connect, (inserted ? std::vector{} : new_in_ports)); + inserted = true; } } - if (!common_outer_loop) { - continue; - } + } + expr_extend = true; + break; + } + // no more extended expr in this loop after go through all potential_extended_exprs, done for this loop. 
+ if (!expr_extend) + continue_extend_parent = false; + } - // not break data dependency + // extend child + bool continue_extend_child = true; + while (continue_extend_child) { + const auto& output_ports = outer_loop_info->get_output_ports(); + const auto& potential_extended_exprs = get_loop_output_exprs(output_ports); + bool expr_extend = false; + for (const auto& child_expr : potential_extended_exprs) { + auto child_expr_loop = child_expr->get_loop_ids(); + if (child_expr_loop.size() < 1) + continue; + const auto& child_last_loop_dim = loop_manager->get_loop_info(child_expr_loop.back())->get_dim_idx(); + // max[0,1] h_max[0]. after outer dim split: max[2,3,1] h_max[2,3]. after inner dim split: max[2,4,3,5] h_max[2,4,3] + if ((expr_loop.size() - child_expr_loop.size() != 1) || child_last_loop_dim == 0) { + continue; + } - // can extend - auto parent_expr_it = linear_ir.find(parent_expr); - if (parent_expr_it != std::prev(outer_loop_bounds_first)) { - linear_ir.move(parent_expr_it, outer_loop_bounds_first); // move new expr in this loop before outer_loop_bounds_first + bool common_outer_loop = true; + // have common outer loop id, child w/o inner most loop. max[2,3,1] h_max[2,3]. + for (auto i = 0; i < std::min(expr_loop.size(), child_expr_loop.size()); i++) { + if (expr_loop[i] != child_expr_loop[i]) { + common_outer_loop = false; + break; } - outer_loop_bounds_first--; - std::cout << "extend child:" << parent_expr->get_node()->get_friendly_name() << std::endl; + } + if (!common_outer_loop) { + continue; + } - // update port. If extend, in_port must be a loop port as parent expr is not in this loop. 
- auto in_port = expr->get_input_port(i); - inner_loop_info->replace_with_new_ports(in_port, parent_expr->get_input_ports()); + // all inputs of child not break data dependency + bool data_conflict = false; + size_t in_num = child_expr->get_input_count(); + for (size_t i = 0; i < in_num; i++) { + auto parent = child_expr->get_input_port_connector(i)->get_source().get_expr(); + if (parent->get_exec_num() > (*outer_loop_bounds_second)->get_exec_num()) { + data_conflict = true; + break; + } + } + if (data_conflict) + continue; - expr = parent_expr; - extend = true; - break; // one parent extended, check this branch from parent. + // can extend + auto child_expr_it = linear_ir.find(child_expr); + if (child_expr_it != outer_loop_bounds_second) { + linear_ir.move(child_expr_it, outer_loop_bounds_second); // move new expr in this loop before outer_loop_bounds_first + } else { + outer_loop_bounds_second++; } - if (!extend) { - check_next = false; // no parent extend, stop extend upper + auto insert_pos = std::find_if(child_expr_loop.begin(), child_expr_loop.end(), [&](size_t loop){ + return loop == loop_position; + }); + // if (insert_pos != parent_loops.end()) { + // std::cout << "insert_pos:" << (*insert_pos) << std::endl; // 17 + // } + child_expr_loop.insert(insert_pos, outer_loop_id); // insert 18 before 17 + child_expr->set_loop_ids(child_expr_loop); + + // update ports. no update, can not extend chain + const auto& new_out_ports = child_expr->get_output_ports(); + bool inserted = false; + const auto child_in_ports = child_expr->get_input_ports(); + for (const auto port : child_in_ports) { + const auto& out_port_connect = port.get_connected_ports().begin(); + if (outer_loop_info->is_loop_port(*out_port_connect)) { + outer_loop_info->replace_with_new_ports(*out_port_connect, (inserted ? 
std::vector{} : new_out_ports)); + inserted = true; + } } + + expr_extend = true; + break; } + if (!expr_extend) + continue_extend_child = false; } } - const auto outer_loop_id = loop_manager->mark_loop(outer_loop_bounds_first, outer_loop_bounds_second, inner_loop_info->get_work_amount(), - outer_increment, inner_loop_info->get_dim_idx(), - in_ports, out_ports, - false, true, loop_position); - const auto& outer_loop_info = loop_manager->get_loop_info(outer_loop_id); - const auto& inner_splitted_loop_info = std::make_shared(inner_loop_info->get_increment(), inner_loop_info->get_input_ports(), inner_loop_info->get_output_ports(), inner_loop_info->get_input_port_descs(), diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 716f969df48d64..23f45981b86cf1 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -71,6 +71,10 @@ Brgemm::Brgemm(const Output& A, const Output& B, const Output& custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c)); } +Brgemm::Brgemm(bool c_pre_scale) { + with_c_pre_ops = c_pre_scale; +} + void Brgemm::custom_constructor_validate_and_infer_types(std::vector layout_a, std::vector layout_b, std::vector layout_c) { INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types); diff --git a/src/common/snippets/tests/src/lowered/pass/split_loop.cpp b/src/common/snippets/tests/src/lowered/pass/split_loop.cpp index bdd4e62c0345d6..42af3035759ab5 100644 --- a/src/common/snippets/tests/src/lowered/pass/split_loop.cpp +++ b/src/common/snippets/tests/src/lowered/pass/split_loop.cpp @@ -17,37 +17,32 @@ using namespace ov::snippets::lowered::pass; class SplitLoopTest : public LoweredPassTestsF { public: - SplitLoopTest() : LoweredPassTestsF() { - comparator.enable(LIRComparator::LIRCmpValues::LOOP_INDICES); - comparator.enable(LIRComparator::LIRCmpValues::PORT_DESCRIPTORS); - 
comparator.enable(LIRComparator::LIRCmpValues::PORT_CONNECTORS); - comparator.enable(LIRComparator::LIRCmpValues::LOOP_MANAGER); - } - void SetUp() override { pipeline.register_pass(); } }; -TEST_F(SplitLoopTest, SplitLoopTestTwoDepthBlocks) { +TEST_F(SplitLoopTest, SplitLoopHmaxInInnerDimensionBlockLoopTest) { size_t vector_size = 16; const auto input_precision = ov::element::f32; const ov::Shape input_shape1{512, 64}; const ov::Shape input_shape2{64, 1024}; const ov::Shape input_shape3{1024, 16}; - /* + /* Brgemm1 and brgemm2 have two block loops. + * HorizonMax, Fill1 and VectorBuffer should be included into block loop of inner most dimension. + * * Param1 Param2 * \ / * Brgemm1 VectorBuffer * | | | - * | Fill Fill + * | Fill0 Fill1 * | \ / * | Maximum * | | * | HorizonMax - * | | + * \ / * Substract Param3 - * | / + * \ / * Brgemm2 * | * Result @@ -67,36 +62,40 @@ TEST_F(SplitLoopTest, SplitLoopTestTwoDepthBlocks) { auto brgemm2 = linear_ir->push_node(sub.second, param3.second); const auto result = linear_ir->push_node(brgemm2.second); const auto& loop_manager = linear_ir->get_loop_manager(); - // two loops for brgemm1 - loop_manager->mark_loop(brgemm1.first, vector_buffer.first, 512, 32, 1, - std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), - LoopPort((*brgemm1.first)->get_input_port(1), false, 1)}, - std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 1)}); + // two loops(N and M) for brgemm1. mark inner first as mark new loop inserted as outer loop as default. 
loop_manager->mark_loop(brgemm1.first, vector_buffer.first, 1024, 64, 0, std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), LoopPort((*brgemm1.first)->get_input_port(1), true, 0)}, std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 0)}); + loop_manager->mark_loop(brgemm1.first, vector_buffer.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 1)}); // loops on column loop_manager->mark_loop(fill.first, h_max.first, 1024, vector_size, 0, - std::vector{LoopPort((*fill.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*fill.first)->get_input_port(0), true, 0), + // skip (*max.first)->get_input_port(0) ? It is vector_buffer, not memory. + // mark reduce_max first and set port, then decompose reduce_max, vector_fill port should not added. + LoopPort((*max.first)->get_input_port(0), false, 0)}, std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); loop_manager->mark_loop(sub.first, brgemm2.first, 1024, vector_size, 0, - std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0), + LoopPort((*sub.first)->get_input_port(1), false, 0)}, std::vector{LoopPort((*sub.first)->get_output_port(0), true, 0)}); // loop on row loop_manager->mark_loop(vector_buffer.first, brgemm2.first, 512, 1, 1, std::vector{LoopPort((*fill.first)->get_input_port(0), true, 1), LoopPort((*sub.first)->get_input_port(0), true, 1)}, std::vector{LoopPort((*sub.first)->get_output_port(0), true, 1)}); - // two loops for brgemm2 - loop_manager->mark_loop(brgemm2.first, result.first, 512, 32, 1, - std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 1), - LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, - std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + // two loops(K and M) for 
brgemm2 loop_manager->mark_loop(brgemm2.first, result.first, 1024, 64, 0, std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 0), LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + loop_manager->mark_loop(brgemm2.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); } { auto param1 = linear_ir_ref->push_node(input_precision, input_shape1); @@ -117,7 +116,8 @@ TEST_F(SplitLoopTest, SplitLoopTestTwoDepthBlocks) { // block inner loops for dimension 0 loop_manager->mark_loop(fill.first, h_max.first, 64, vector_size, 0, std::vector{LoopPort((*fill.first)->get_input_port(0), true, 0), - LoopPort((*max.first)->get_input_port(0), false, 0)}, // Max(initial_fill, fill) + // skip (*max.first)->get_input_port(0) ? It is vector_buffer, not memory. + LoopPort((*max.first)->get_input_port(0), false, 0)}, std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); loop_manager->mark_loop(sub.first, brgemm2.first, 64, vector_size, 0, std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0), @@ -129,19 +129,590 @@ TEST_F(SplitLoopTest, SplitLoopTestTwoDepthBlocks) { LoopPort((*sub.first)->get_input_port(0), true, 1)}, std::vector{LoopPort((*sub.first)->get_output_port(0), true, 1)}); // two block loops. All exprs between two brgemm including h_max should be in both block loops. 
+ loop_manager->mark_loop(brgemm1.first, result.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), + LoopPort((*brgemm1.first)->get_input_port(1), true, 0), + LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); loop_manager->mark_loop(brgemm1.first, result.first, 512, 32, 1, std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), LoopPort((*brgemm1.first)->get_input_port(1), false, 1), LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + } +} + +// extend child of Hsum, and move up as exprs with same loop id should be in a line together to execute. +TEST_F(SplitLoopTest, SplitLoopExtendHsumChildChain) { + size_t vector_size = 16; + const auto input_precision = ov::element::f32; + const ov::Shape input_shape1{512, 64}; + const ov::Shape input_shape2{64, 1024}; + const ov::Shape input_shape3{1024, 16}; + const ov::Shape buf_shape{512, 1}; + /* Brgemm1 and brgemm2 have two block loops. + * HorizonMax and Multiply extended, Multiply is moved up. Divide not moved as data dependency. 
+ * + * Param1 Param2 + * \ / + * Brgemm1 + * | | + * | Relu + * | | + * | HorizonMax Buffer + * \ | \ / + * Substract Multiply + * | | | + * | HorizonSum | + * | \ / + * | Divide param3 + * \ | / + * Brgemm2 + * | + * Result + */ + { + auto param1 = linear_ir->push_node(input_precision, input_shape1); + auto param2 = linear_ir->push_node(input_precision, input_shape2); + auto param3 = linear_ir->push_node(input_precision, input_shape3); + auto buffer = linear_ir->push_node(buf_shape, input_precision); + auto brgemm1 = linear_ir->push_node(param1.second, param2.second); + auto relu = linear_ir->push_node(brgemm1.second); + auto h_max = linear_ir->push_node(relu.second); + auto sub = linear_ir->push_node(brgemm1.second, h_max.second); + auto h_sum = linear_ir->push_node(sub.second); + auto mul = linear_ir->push_node(buffer.second, h_max.second); + auto div = linear_ir->push_node(mul.second, h_sum.second); + auto brgemm2 = linear_ir->push_node(sub.second, param3.second, div.second); + const auto result = linear_ir->push_node(brgemm2.second); + const auto& loop_manager = linear_ir->get_loop_manager(); + // two loops for brgemm1 + loop_manager->mark_loop(brgemm1.first, relu.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), + LoopPort((*brgemm1.first)->get_input_port(1), true, 0)}, + // after split, result in reused buffer, should not increment + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 0)}); + loop_manager->mark_loop(brgemm1.first, relu.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 1)}); + // loops on column + loop_manager->mark_loop(relu.first, h_max.first, 1024, vector_size, 0, + std::vector{LoopPort((*relu.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*relu.first)->get_output_port(0), true, 0)}); + 
loop_manager->mark_loop(sub.first, h_sum.first, 1024, vector_size, 0, + std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 0)}); + // loop on row + loop_manager->mark_loop(relu.first, brgemm2.first, 512, 1, 1, + std::vector{LoopPort((*relu.first)->get_input_port(0), true, 1), + LoopPort((*sub.first)->get_input_port(0), true, 1), + LoopPort((*mul.first)->get_input_port(0), true, 1)}, + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 1), + LoopPort((*div.first)->get_output_port(0), true, 1)}); + // two loops for brgemm2 + loop_manager->mark_loop(brgemm2.first, result.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 0), + LoopPort((*brgemm2.first)->get_input_port(1), true, 0), + LoopPort((*brgemm2.first)->get_input_port(2), false, 0)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + loop_manager->mark_loop(brgemm2.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1), + LoopPort((*brgemm2.first)->get_input_port(2), true, 1)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + } + { + auto param1 = linear_ir_ref->push_node(input_precision, input_shape1); + auto param2 = linear_ir_ref->push_node(input_precision, input_shape2); + auto param3 = linear_ir_ref->push_node(input_precision, input_shape3); + auto buffer = linear_ir_ref->push_node(buf_shape, input_precision); + auto brgemm1 = linear_ir_ref->push_node(param1.second, param2.second); + auto relu = linear_ir_ref->push_node(brgemm1.second); + auto h_max = linear_ir_ref->push_node(relu.second); + auto mul = linear_ir_ref->push_node(h_max.second, buffer.second); + auto sub = linear_ir_ref->push_node(brgemm1.second, h_max.second); + auto h_sum = linear_ir_ref->push_node(sub.second); + auto div = linear_ir_ref->push_node(mul.second, 
h_sum.second); + auto brgemm2 = linear_ir_ref->push_node(sub.second, param3.second, div.second); + const auto result = linear_ir_ref->push_node(brgemm2.second); + const auto& loop_manager = linear_ir_ref->get_loop_manager(); + + // block inner loops for dimension 0 + loop_manager->mark_loop(relu.first, h_max.first, 64, vector_size, 0, + std::vector{LoopPort((*relu.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*relu.first)->get_output_port(0), true, 0)}); + loop_manager->mark_loop(sub.first, h_sum.first, 64, vector_size, 0, + std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0), + LoopPort((*sub.first)->get_input_port(1), false, 0)}, + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 0)}); + + // block inner loop for dimension 1. + loop_manager->mark_loop(relu.first, brgemm2.first, 32, 1, 1, + std::vector{LoopPort((*relu.first)->get_input_port(0), true, 1), + LoopPort((*sub.first)->get_input_port(0), true, 1), + LoopPort((*mul.first)->get_input_port(0), true, 1)}, + std::vector{LoopPort((*sub.first)->get_output_port(0), true, 1), + LoopPort((*div.first)->get_output_port(0), true, 1)}); + // two block loops. All exprs between two brgemm including h_max should be in both block loops. 
loop_manager->mark_loop(brgemm1.first, result.first, 1024, 64, 0, std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), LoopPort((*brgemm1.first)->get_input_port(1), true, 0), + LoopPort((*mul.first)->get_input_port(0), false, 0), LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + loop_manager->mark_loop(brgemm1.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1), + LoopPort((*mul.first)->get_input_port(0), true, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); } } +// full falsh attention case +TEST_F(SplitLoopTest, SplitLoopFlashAttentionTest) { + size_t vector_size = 16; + const auto input_precision = ov::element::f32; + const ov::Shape input_shape1{512, 64}; + const ov::Shape input_shape2{64, 1024}; + const ov::Shape input_shape3{1024, 16}; + const ov::Shape buf_shape{512, 1}; + { + auto param1 = linear_ir->push_node(input_precision, input_shape1); + auto param2 = linear_ir->push_node(input_precision, input_shape2); + auto param3 = linear_ir->push_node(input_precision, input_shape3); + auto buffer_max = linear_ir->push_node(buf_shape, input_precision); + auto buffer_sum = linear_ir->push_node(buf_shape, input_precision); + auto brgemm1 = linear_ir->push_node(param1.second, param2.second); + // softmax max + const auto vector_buffer_max = linear_ir->push_node(input_precision); + uint32_t fill_value_max = 0xff7fffff; + const auto initial_fill_max = linear_ir->push_node(vector_buffer_max.second, 0, fill_value_max); + const auto fill_max = linear_ir->push_node(brgemm1.second, vector_size, fill_value_max); + auto max = linear_ir->push_node(initial_fill_max.second, fill_max.second); + auto h_max = linear_ir->push_node(max.second); + // scale + auto sub_scale = 
linear_ir->push_node(buffer_max.second, h_max.second); + auto exp_scale = linear_ir->push_node(sub_scale.second); + + // softmax sum + const auto vector_buffer_sum = linear_ir->push_node(input_precision); + uint32_t fill_value_sum = 0x00000000; + const auto initial_fill_sum = linear_ir->push_node(vector_buffer_sum.second, 0, fill_value_sum); + + auto max_new = linear_ir->push_node(buffer_max.second, h_max.second); // max of old and new + auto sub_softmax = linear_ir->push_node(brgemm1.second, max_new.second); + auto exp = linear_ir->push_node(sub_softmax.second); + const auto fill_sum = linear_ir->push_node(exp.second, vector_size, fill_value_max); + auto add = linear_ir->push_node(initial_fill_sum.second, fill_sum.second); + auto h_sum = linear_ir->push_node(add.second); + + // softmax multiply + auto power_static = linear_ir->push_node(h_sum.second, -1); + auto mul_softmax = linear_ir->push_node(exp.second, power_static.second); + // scale + auto mul_scale = linear_ir->push_node(buffer_sum.second, exp_scale.second); + auto scale = linear_ir->push_node(mul_scale.second, h_sum.second); + + auto brgemm2 = linear_ir->push_node(mul_softmax.second, param3.second, scale.second); + const auto result = linear_ir->push_node(brgemm2.second); + const auto& loop_manager = linear_ir->get_loop_manager(); + // two loops for brgemm1. Mark the inner loop first, because a newly marked loop is inserted as the outer loop by default. 
+ size_t brgemm1_n = loop_manager->mark_loop(brgemm1.first, vector_buffer_max.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), + LoopPort((*brgemm1.first)->get_input_port(1), true, 0)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 0)}); + size_t brgemm1_m = loop_manager->mark_loop(brgemm1.first, vector_buffer_max.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 1)}); + + // three loops on column-1024 [512,1024] + size_t column_loop1 = loop_manager->mark_loop(fill_max.first, h_max.first, 1024, vector_size, 0, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); + size_t column_loop2 = loop_manager->mark_loop(max_new.first, h_sum.first, 1024, vector_size, 0, + std::vector{LoopPort((*max_new.first)->get_input_port(0), false, 0), + LoopPort((*max_new.first)->get_input_port(1), false, 0), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*exp.first)->get_output_port(0), true, 0), + LoopPort((*add.first)->get_output_port(0), true, 0)}); + size_t column_loop3 = loop_manager->mark_loop(power_static.first, mul_scale.first, 1024, vector_size, 0, + std::vector{LoopPort((*power_static.first)->get_input_port(0), false, 0), + LoopPort((*mul_softmax.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), true, 0)}); + // one loop on row-512 [512,1024] + size_t row_loop = loop_manager->mark_loop(vector_buffer_max.first, brgemm2.first, 512, 1, 1, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 1), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 1), + LoopPort((*sub_scale.first)->get_input_port(0), true, 1), + LoopPort((*max_new.first)->get_input_port(0), 
true, 1), + LoopPort((*mul_scale.first)->get_input_port(0), true, 1)}, + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), true, 1), + LoopPort((*scale.first)->get_output_port(0), true, 1)}); + // two loops for brgemm2 + size_t brgemm2_k = loop_manager->mark_loop(brgemm2.first, result.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 0), + LoopPort((*brgemm2.first)->get_input_port(1), true, 0), + LoopPort((*brgemm2.first)->get_input_port(2), false, 0)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + size_t brgemm2_m = loop_manager->mark_loop(brgemm2.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1), + LoopPort((*brgemm2.first)->get_input_port(2), true, 1)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + std::cout << "brgemm1_m:" << brgemm1_m << std::endl; + std::cout << "brgemm1_n:" << brgemm1_n << std::endl; + std::cout << "column_loop1:" << column_loop1 << std::endl; + std::cout << "column_loop2:" << column_loop2 << std::endl; + std::cout << "column_loop3:" << column_loop3 << std::endl; + std::cout << "row_loop:" << row_loop << std::endl; + std::cout << "brgemm2_m:" << brgemm2_m << std::endl; + std::cout << "brgemm2_k:" << brgemm2_k << std::endl; + } + { + auto param1 = linear_ir_ref->push_node(input_precision, input_shape1); + auto param2 = linear_ir_ref->push_node(input_precision, input_shape2); + auto param3 = linear_ir_ref->push_node(input_precision, input_shape3); + auto buffer_max = linear_ir_ref->push_node(buf_shape, input_precision); + auto buffer_sum = linear_ir_ref->push_node(buf_shape, input_precision); + auto brgemm1 = linear_ir_ref->push_node(param1.second, param2.second); + // softmax max + const auto vector_buffer_max = linear_ir_ref->push_node(input_precision); + uint32_t fill_value_max = 0xff7fffff; + const auto initial_fill_max = 
linear_ir_ref->push_node(vector_buffer_max.second, 0, fill_value_max); + const auto fill_max = linear_ir_ref->push_node(brgemm1.second, vector_size, fill_value_max); + auto max = linear_ir_ref->push_node(initial_fill_max.second, fill_max.second); + auto h_max = linear_ir_ref->push_node(max.second); + // scale + auto sub_scale = linear_ir_ref->push_node(buffer_max.second, h_max.second); + auto exp_scale = linear_ir_ref->push_node(sub_scale.second); + auto mul_scale = linear_ir_ref->push_node(buffer_sum.second, exp_scale.second); // moved up to here + + // softmax sum + const auto vector_buffer_sum = linear_ir_ref->push_node(input_precision); + uint32_t fill_value_sum = 0x00000000; + const auto initial_fill_sum = linear_ir_ref->push_node(vector_buffer_sum.second, 0, fill_value_sum); + + auto max_new = linear_ir_ref->push_node(buffer_max.second, h_max.second); // max of old and new + auto sub_softmax = linear_ir_ref->push_node(brgemm1.second, max_new.second); + auto exp = linear_ir_ref->push_node(sub_softmax.second); + const auto fill_sum = linear_ir_ref->push_node(exp.second, vector_size, fill_value_max); + auto add = linear_ir_ref->push_node(initial_fill_sum.second, fill_sum.second); + auto h_sum = linear_ir_ref->push_node(add.second); + // scale, moved up here + auto scale = linear_ir_ref->push_node(mul_scale.second, h_sum.second); + + // softmax multiply + auto power_static = linear_ir_ref->push_node(h_sum.second, -1); + auto mul_softmax = linear_ir_ref->push_node(exp.second, power_static.second); + + auto brgemm2 = linear_ir_ref->push_node(mul_softmax.second, param3.second, scale.second); + const auto result = linear_ir_ref->push_node(brgemm2.second); + const auto& loop_manager = linear_ir_ref->get_loop_manager(); + + // three block inner loops for dimension 0(0 means inner most dimension) + loop_manager->mark_loop(fill_max.first, h_max.first, 64, vector_size, 0, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 0)}, + 
std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); + loop_manager->mark_loop(max_new.first, h_sum.first, 64, vector_size, 0, + std::vector{LoopPort((*max_new.first)->get_input_port(0), false, 0), + LoopPort((*max_new.first)->get_input_port(1), false, 0), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*exp.first)->get_output_port(0), true, 0), + // result of add should not store, just Hsum on vec reg. + LoopPort((*add.first)->get_output_port(0), true, 0)}); + loop_manager->mark_loop(power_static.first, brgemm2.first, 64, vector_size, 0, + std::vector{LoopPort((*power_static.first)->get_input_port(0), false, 0), + LoopPort((*mul_softmax.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), true, 0)}); + + // block inner loop for dimension 1. + // there will be buffers inserted as brgemm1 result and brgemm2 input0. + // below ports will be connected with the block buffers and load/store inserted based on buffer shape. + loop_manager->mark_loop(vector_buffer_max.first, brgemm2.first, 32, 1, 1, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 1), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 1), + // three buffers inc 1 on each row loop + LoopPort((*sub_scale.first)->get_input_port(0), true, 1), + LoopPort((*max_new.first)->get_input_port(0), true, 1), + LoopPort((*mul_scale.first)->get_input_port(0), true, 1)}, + // inc 64. store to buffer, inc based on buffer shape. + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), true, 1), + // inc 1. one row get one scale. scale buffer(32*1) should be inserted. + LoopPort((*scale.first)->get_output_port(0), true, 1)}); + + // two block loops. All exprs between two brgemm are in both block loops. 
+ size_t block_nk = loop_manager->mark_loop(brgemm1.first, result.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), + LoopPort((*brgemm1.first)->get_input_port(1), true, 0), // K matrix inc 64(*1) on 1024 dimension + // three buffers ports are not increased on N/K block loop. + LoopPort((*max_new.first)->get_input_port(0), false, 0), + LoopPort((*sub_scale.first)->get_input_port(0), false, 0), + LoopPort((*mul_scale.first)->get_input_port(0), false, 0), + LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, // V matrix inc 64(*16) on 1024 dimension + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + size_t block_m = loop_manager->mark_loop(brgemm1.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1), // K matrix not inc + // three buffers ports are increased 32 on M block loop + LoopPort((*max_new.first)->get_input_port(0), true, 1), + LoopPort((*sub_scale.first)->get_input_port(0), true, 1), + LoopPort((*mul_scale.first)->get_input_port(0), true, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, // V matrix not inc + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + } + /* + // after split_loop, not apply inside fuse_loop + { + auto param1 = linear_ir_ref->push_node(input_precision, input_shape1); + auto param2 = linear_ir_ref->push_node(input_precision, input_shape2); + auto param3 = linear_ir_ref->push_node(input_precision, input_shape3); + auto buffer_max = linear_ir_ref->push_node(buf_shape, input_precision); + auto buffer_sum = linear_ir_ref->push_node(buf_shape, input_precision); + auto brgemm1 = linear_ir_ref->push_node(param1.second, param2.second); + // softmax max + const auto vector_buffer_max = linear_ir_ref->push_node(input_precision); + uint32_t fill_value_max = 0xff7fffff; + const auto initial_fill_max = 
linear_ir_ref->push_node(vector_buffer_max.second, 0, fill_value_max); + const auto fill_max = linear_ir_ref->push_node(brgemm1.second, vector_size, fill_value_max); + auto max = linear_ir_ref->push_node(initial_fill_max.second, fill_max.second); + auto h_max = linear_ir_ref->push_node(max.second); + // scale + auto sub_scale = linear_ir_ref->push_node(buffer_max.second, h_max.second); + auto exp_scale = linear_ir_ref->push_node(sub_scale.second); + auto mul_scale = linear_ir_ref->push_node(buffer_sum.second, exp_scale.second); // moved up to here + + // softmax sum + const auto vector_buffer_sum = linear_ir_ref->push_node(input_precision); + uint32_t fill_value_sum = 0x00000000; + const auto initial_fill_sum = linear_ir_ref->push_node(vector_buffer_sum.second, 0, fill_value_sum); + + auto max_new = linear_ir_ref->push_node(buffer_max.second, h_max.second); // max of old and new + auto sub_softmax = linear_ir_ref->push_node(brgemm1.second, max_new.second); + auto exp = linear_ir_ref->push_node(sub_softmax.second); + const auto fill_sum = linear_ir_ref->push_node(exp.second, vector_size, fill_value_max); + auto add = linear_ir_ref->push_node(initial_fill_sum.second, fill_sum.second); + auto h_sum = linear_ir_ref->push_node(add.second); + // scale, moved up here + auto scale = linear_ir_ref->push_node(mul_scale.second, h_sum.second); + + // softmax multiply + auto power_static = linear_ir_ref->push_node(h_sum.second, -1); + auto mul_softmax = linear_ir_ref->push_node(exp.second, power_static.second); + + auto brgemm2 = linear_ir_ref->push_node(mul_softmax.second, param3.second, scale.second); + const auto result = linear_ir_ref->push_node(brgemm2.second); + const auto& loop_manager = linear_ir_ref->get_loop_manager(); + + // two loops for brgemm1. mark inner first as mark new loop inserted as outer loop as default. 
+ size_t brgemm1_n = loop_manager->mark_loop(brgemm1.first, vector_buffer_max.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), + LoopPort((*brgemm1.first)->get_input_port(1), true, 0)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 0)}); + size_t brgemm1_m = loop_manager->mark_loop(brgemm1.first, vector_buffer_max.first, 512, 32, 1, + std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), + LoopPort((*brgemm1.first)->get_input_port(1), false, 1)}, + std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 1)}); + + // three inner dimension block inside loops(wa:64, inc:16). + size_t column_loop1 = loop_manager->mark_loop(fill_max.first, h_max.first, 64, vector_size, 0, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); + size_t column_loop2 = loop_manager->mark_loop(max_new.first, h_sum.first, 64, vector_size, 0, + std::vector{LoopPort((*max_new.first)->get_input_port(0), false, 0), + LoopPort((*max_new.first)->get_input_port(1), false, 0), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*exp.first)->get_output_port(0), true, 0), + LoopPort((*add.first)->get_output_port(0), true, 0)}); + size_t column_loop3 = loop_manager->mark_loop(power_static.first, mul_scale.first, 64, vector_size, 0, + std::vector{LoopPort((*power_static.first)->get_input_port(0), false, 0), + LoopPort((*mul_softmax.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), true, 0)}); + + // one outer dimension block inside loop(wa:32, inc:1). 
+ loop_manager->mark_loop(vector_buffer_max.first, brgemm2.first, 32, 1, 1, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 1), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 1), + // three buffers inc 1(*64) on each row loop + LoopPort((*sub_scale.first)->get_input_port(0), true, 1), + LoopPort((*max_new.first)->get_input_port(0), true, 1), + LoopPort((*mul_scale.first)->get_input_port(0), true, 1)}, + // inc 1(*64). store to buffer, inc based on buffer shape. + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), true, 1), + // inc 1*(1). one row get one scale. scale buffer(32*1) should be inserted. + LoopPort((*scale.first)->get_output_port(0), true, 1)}); + + // three inner dimension block loops(wa:1024, inc:64). these three will be fused, and finally fused with N in brgemm1 and K in brgemm2 + // parent extend to vector_buffer_max. child extend to mul_scale(vector_buffer_sum-loop_end) + size_t column_loop1 = loop_manager->mark_loop(vector_buffer_max.first, vector_buffer_sum.first, 64, vector_size, 0, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 0), + LoopPort((*mul_scale.first)->get_input_port(0), false, 0)}, // add + std::vector{LoopPort((*h_max.first)->get_output_port(0), false, 0)}); // update + // parent extend to vector_buffer_sum. child extend to scale(loop_end is power_static). 
+ size_t column_loop2 = loop_manager->mark_loop(vector_buffer_sum.first, power_static.first, 64, vector_size, 0, + std::vector{LoopPort((*max_new.first)->get_input_port(0), false, 0), + LoopPort((*max_new.first)->get_input_port(1), false, 0), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 0), + LoopPort((*scale.first)->get_input_port(0), false, 0)}, // add + std::vector{LoopPort((*exp.first)->get_output_port(0), true, 0), + // update from (*add.first)->get_output_port(0) to h_sum + LoopPort((*h_sum.first)->get_output_port(0), false, 0), + LoopPort((*scale.first)->get_output_port(0), false, 0)}); // add + size_t column_loop3 = loop_manager->mark_loop(power_static.first, mul_scale.first, 64, vector_size, 0, + std::vector{LoopPort((*power_static.first)->get_input_port(0), false, 0), + LoopPort((*mul_softmax.first)->get_input_port(0), true, 0)}, + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), true, 0)}); + + // one outer dimension block loops(wa:512, inc:32). fused to M block in brgemm1 and brgemm2 + loop_manager->mark_loop(vector_buffer_max.first, brgemm2.first, 512, 32, 1, + std::vector{LoopPort((*fill_max.first)->get_input_port(0), true, 1), + LoopPort((*sub_softmax.first)->get_input_port(0), true, 1), + // three buffers inc 32(*1) + LoopPort((*sub_scale.first)->get_input_port(0), true, 1), + LoopPort((*max_new.first)->get_input_port(0), true, 1), + LoopPort((*mul_scale.first)->get_input_port(0), true, 1)}, + // 32(*64) buffer inserted and reused + std::vector{LoopPort((*mul_softmax.first)->get_output_port(0), false, 1), + // 32(*1) buffer inserted and reused + LoopPort((*scale.first)->get_output_port(0), false, 1)}); + + // two loops for brgemm2 + size_t brgemm2_k = loop_manager->mark_loop(brgemm2.first, result.first, 1024, 64, 0, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 0), + LoopPort((*brgemm2.first)->get_input_port(1), true, 0), + LoopPort((*brgemm2.first)->get_input_port(2), false, 0)}, + 
std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); + size_t brgemm2_m = loop_manager->mark_loop(brgemm2.first, result.first, 512, 32, 1, + std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 1), + LoopPort((*brgemm2.first)->get_input_port(1), false, 1), + LoopPort((*brgemm2.first)->get_input_port(2), true, 1)}, + std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); + } + /**/ +} + +// TEST_F(SplitLoopTest, SplitLoopTestReduceOp) { +// size_t vector_size = 16; +// const auto input_precision = ov::element::f32; +// const ov::Shape input_shape1{512, 64}; +// const ov::Shape input_shape2{64, 1024}; +// const ov::Shape input_shape3{1024, 16}; +// const ov::Shape buf_shape{512, 1}; +// /* +// * Param1 Param2 +// * \ / +// * Brgemm1 +// * | | +// * | ReduceMax Buffer +// * | \ / | +// * | Maximum | +// * | | | | +// * Subtract | | +// * | \ / +// * | Add +// * \ / +// * Multiply Param3 +// * \ / +// * Brgemm2 +// * | +// * Result +// */ +// // Maximum has one loop; Add is moved after Max. Exprs in the same loop are kept together. 
+// { +// auto param1 = linear_ir->push_node(input_precision, input_shape1); +// auto param2 = linear_ir->push_node(input_precision, input_shape2); +// auto param3 = linear_ir->push_node(input_precision, input_shape3); +// auto buffer = linear_ir->push_node(buf_shape, input_precision); +// auto brgemm1 = linear_ir->push_node(param1.second, param2.second); +// auto reduce_max = linear_ir->push_node(brgemm1.second, 1); +// auto max = linear_ir->push_node(reduce_max.second, buffer.second); +// auto sub = linear_ir->push_node(brgemm1.second, max.second); +// auto add = linear_ir->push_node(max.second, buffer.second); +// auto mul = linear_ir->push_node(sub.second, add.second); +// auto brgemm2 = linear_ir->push_node(mul.second, param3.second); +// const auto result = linear_ir->push_node(brgemm2.second); +// const auto& loop_manager = linear_ir->get_loop_manager(); +// // two loops for brgemm1 +// loop_manager->mark_loop(brgemm1.first, reduce_max.first, 512, 32, 1, +// std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), +// LoopPort((*brgemm1.first)->get_input_port(1), false, 1)}, +// std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 1)}); +// loop_manager->mark_loop(brgemm1.first, reduce_max.first, 1024, 64, 0, +// std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), +// LoopPort((*brgemm1.first)->get_input_port(1), true, 0)}, +// std::vector{LoopPort((*brgemm1.first)->get_output_port(0), true, 0)}); +// // loops on column +// loop_manager->mark_loop(reduce_max.first, max.first, 1024, vector_size, 0, +// std::vector{LoopPort((*reduce_max.first)->get_input_port(0), true, 0)}, +// std::vector{LoopPort((*reduce_max.first)->get_output_port(0), false, 0)}); +// loop_manager->mark_loop(sub.first, brgemm2.first, 1024, vector_size, 0, +// std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0), +// LoopPort((*sub.first)->get_input_port(1), false, 0), +// LoopPort((*mul.first)->get_input_port(1), false, 0)}, +// 
std::vector{LoopPort((*mul.first)->get_output_port(0), true, 0)}); +// // loop on row +// loop_manager->mark_loop(reduce_max.first, brgemm2.first, 512, 1, 1, +// std::vector{LoopPort((*reduce_max.first)->get_input_port(0), true, 1), +// LoopPort((*max.first)->get_input_port(1), true, 1), +// LoopPort((*sub.first)->get_input_port(0), true, 1), +// LoopPort((*add.first)->get_input_port(1), true, 1)}, +// std::vector{LoopPort((*mul.first)->get_output_port(0), true, 1)}); +// // two loops for brgemm2 +// loop_manager->mark_loop(brgemm2.first, result.first, 512, 32, 1, +// std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 1), +// LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, +// std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); +// loop_manager->mark_loop(brgemm2.first, result.first, 1024, 64, 0, +// std::vector{LoopPort((*brgemm2.first)->get_input_port(0), true, 0), +// LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, +// std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); +// } +// { +// auto param1 = linear_ir_ref->push_node(input_precision, input_shape1); +// auto param2 = linear_ir_ref->push_node(input_precision, input_shape2); +// auto param3 = linear_ir_ref->push_node(input_precision, input_shape3); +// auto brgemm1 = linear_ir_ref->push_node(param1.second, param2.second); +// const auto vector_buffer = linear_ir_ref->push_node(input_precision); +// uint32_t fill_value = 0xff7fffff; +// const auto initial_fill = linear_ir_ref->push_node(vector_buffer.second, 0, fill_value); +// const auto fill = linear_ir_ref->push_node(brgemm1.second, vector_size, fill_value); +// auto max = linear_ir_ref->push_node(initial_fill.second, fill.second); +// auto h_max = linear_ir_ref->push_node(max.second); +// auto sub = linear_ir_ref->push_node(brgemm1.second, h_max.second); +// auto brgemm2 = linear_ir_ref->push_node(sub.second, param3.second); +// const auto result = 
linear_ir_ref->push_node(brgemm2.second); +// const auto& loop_manager = linear_ir_ref->get_loop_manager(); + +// // block inner loops for dimension 0 +// loop_manager->mark_loop(fill.first, h_max.first, 64, vector_size, 0, +// std::vector{LoopPort((*fill.first)->get_input_port(0), true, 0), +// LoopPort((*max.first)->get_input_port(0), false, 0)}, // Max(initial_fill, fill) +// std::vector{LoopPort((*max.first)->get_output_port(0), true, 0)}); +// loop_manager->mark_loop(sub.first, brgemm2.first, 64, vector_size, 0, +// std::vector{LoopPort((*sub.first)->get_input_port(0), true, 0), +// LoopPort((*sub.first)->get_input_port(1), false, 0)}, // sub:brgemm1-Hmax +// std::vector{LoopPort((*sub.first)->get_output_port(0), true, 0)}); +// // block inner loop for dimension 1. +// loop_manager->mark_loop(vector_buffer.first, brgemm2.first, 32, 1, 1, +// std::vector{LoopPort((*fill.first)->get_input_port(0), true, 1), +// LoopPort((*sub.first)->get_input_port(0), true, 1)}, +// std::vector{LoopPort((*sub.first)->get_output_port(0), true, 1)}); +// // two block loops. All exprs between two brgemm including h_max should be in both block loops. 
+// loop_manager->mark_loop(brgemm1.first, result.first, 512, 32, 1, +// std::vector{LoopPort((*brgemm1.first)->get_input_port(0), true, 1), +// LoopPort((*brgemm1.first)->get_input_port(1), false, 1), +// LoopPort((*brgemm2.first)->get_input_port(1), false, 1)}, +// std::vector{LoopPort((*brgemm2.first)->get_output_port(0), true, 1)}); +// loop_manager->mark_loop(brgemm1.first, result.first, 1024, 64, 0, +// std::vector{LoopPort((*brgemm1.first)->get_input_port(0), false, 0), +// LoopPort((*brgemm1.first)->get_input_port(1), true, 0), +// LoopPort((*brgemm2.first)->get_input_port(1), true, 0)}, +// std::vector{LoopPort((*brgemm2.first)->get_output_port(0), false, 0)}); +// } +// } + } // namespace snippets } // namespace test } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp index 5b408bb0b611cc..e48054be88d629 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp @@ -32,7 +32,7 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, BRGEMM_TYPE t const size_t offset_a, const size_t offset_b, const size_t offset_c_scale, const size_t offset_c, std::vector layout_a, std::vector layout_b, std::vector layout_c_scale, std::vector layout_c) - : Brgemm(), m_type(type), has_c_pre_scale(true) { + : Brgemm(true), m_type(type), has_c_pre_scale(true) { // We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call std::cout << "BrgemmCPU1" << std::endl; set_arguments({A, B, C}); @@ -77,7 +77,7 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, BRGEMM_TYPE t const PortDescriptor& desc_c, const PortDescriptor& desc_c_scale, std::vector layout_a, std::vector layout_b, std::vector layout_c_scale, std::vector layout_c) - : Brgemm(), m_type(type), 
has_c_pre_scale(true) { + : Brgemm(true), m_type(type), has_c_pre_scale(true) { set_arguments({A, B, C}); set_output_size(1); m_input_ports = {{0, desc_a}, {1, desc_b}, {2, desc_c_scale}}; diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp index 51565537c43568..501d0732d88fbc 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp @@ -88,8 +88,11 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir, const auto brgemm = ov::as_type_ptr(brgemm_expr->get_node()); const auto type = brgemm->get_type(); - if (stand_alone(type)) + if (stand_alone(type)) { + std::cout << "BrgemmCPUBlocking stand_alone..............." << std::endl; return ov::snippets::lowered::pass::BrgemmBlockingBase::mark_blocking_loops(linear_ir, brgemm_it, m_block, n_block, k_block); + } + // return ov::snippets::lowered::pass::BrgemmBlockingBase::mark_blocking_loops(linear_ir, brgemm_it, m_block, n_block, k_block); brgemm_expr->get_input_port_descriptor(0)->set_subtensor({m_block, k_block}); brgemm_expr->get_input_port_descriptor(1)->set_subtensor({k_block, n_block});