Skip to content

Commit

Permalink
add split_loop test with block on dimension 1 and 0
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Oct 29, 2024
1 parent 03c2a91 commit b43dfe9
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 67 deletions.
7 changes: 5 additions & 2 deletions src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear

const auto& loop_manager = linear_ir.get_loop_manager();
if (!ov::snippets::utils::is_full_dim_value(k_block)) {
const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true, 0),
LoopPort(brgemm_expr->get_input_port(1), true, 1)};
// const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true, 0),
// LoopPort(brgemm_expr->get_input_port(1), true, 1)};
// const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true),
LoopPort(brgemm_expr->get_input_port(1), true)};
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
mark_k_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, k_block);
}
Expand Down
134 changes: 69 additions & 65 deletions src/common/snippets/src/lowered/pass/split_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin,
// in case if loops are not split but FuseLoops is registered in pass manager after SplitLoops
if (loop_was_split)
FuseLoops().run(linear_ir, begin, end);
//
// FuseLoops covers the loop port update needed after this split
// for (auto expr_it = begin; expr_it != end; expr_it++) {
// const auto expr = *expr_it;
// const auto& node = expr->get_node();
Expand All @@ -165,9 +165,6 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin,
// std::cout << "loop_dim:" << loop_info->get_dim_idx() << std::endl;
// std::cout << "wa:" << loop_info->get_work_amount() << std::endl;
// std::cout << "inc:" << loop_info->get_increment() << std::endl;
// if (loop_info->get_dim_idx() == 18446744073709551615) {
// loop_info->set_dim_idx(0);
// }
// auto in_ports = loop_info->get_input_ports();
// for (auto in_port : in_ports) {
// auto in_port_parent = in_port.expr_port->get_connected_ports().begin()->get_expr();
Expand All @@ -176,8 +173,13 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin,

// auto out_ports = loop_info->get_output_ports();
// for (auto out_port : out_ports) {
// auto out_port_child = out_port.expr_port->get_connected_ports().begin()->get_expr();
// std::cout << "out_port_child to:" << out_port_child->get_node()->get_friendly_name() << std::endl;
// auto consumers = out_port.expr_port->get_connected_ports();
// std::cout << "out_port_child is one of:" << std::endl;
// for (auto consumer : consumers) {
// auto out_port_child = consumer.get_expr();
// std::cout << "out_port_child to:" << out_port_child->get_node()->get_friendly_name() << std::endl;
// }
// std::cout << "out_port_child is one of end."<< std::endl;
// }
// }
// }
Expand All @@ -199,18 +201,18 @@ size_t SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t ou
auto outer_loop_bounds_second = loop_bounds.second;
auto in_ports = inner_loop_info->get_input_ports();
auto out_ports = inner_loop_info->get_output_ports();

// Extend the outer blocking loop of the innermost dimension to children that have no innermost loop.
// for example, Hsum should perform on every on M_blk*K_blk*M_in_block loops.
// For example, Hmax should be performed on every M_blk*K_blk*M_in_block loop.
if (inner_loop_info->get_dim_idx() == 0) {
// extend child
for (auto expr_it = outer_loop_bounds_first; expr_it != outer_loop_bounds_second; expr_it++) {
auto expr = *expr_it;
std::cout << "expr:" << expr->get_node()->get_friendly_name() << std::endl;
const auto& expr_loop = expr->get_loop_ids();
const auto& expr_loop = expr->get_loop_ids(); // expr_loop is [2,3,1] child is [2,3]
if (expr->get_output_count() != 1)
break;
auto expr_out_port = expr->get_output_port(0);
bool check_next = true;
bool extend_global = false;
while (check_next) {
if (expr->get_output_count() != 1)
break;
Expand All @@ -233,7 +235,7 @@ size_t SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t ou
if (child_expr_loop.size() < 1)
continue;
const auto& child_last_loop_dim = loop_manager->get_loop_info<UnifiedLoopInfo>(child_expr_loop.back())->get_dim_idx();
if ((expr_loop.size() - child_expr_loop.size() != 1) && child_last_loop_dim == 0) {
if ((expr_loop.size() - child_expr_loop.size() != 1) || child_last_loop_dim == 0) {
continue;
}
// check expr and child have common outer loop id
Expand Down Expand Up @@ -270,81 +272,83 @@ size_t SplitLoops::split(LinearIR& linear_ir, size_t loop_to_split_id, size_t ou
linear_ir.move(child_expr_it, outer_loop_bounds_second); // move new expr in this loop before loop_bounds_end
}
std::cout << "extend child:" << child_expr->get_node()->get_friendly_name() << std::endl;
auto expr_out_port = expr->get_output_port(0);
inner_loop_info->replace_with_new_ports(expr_out_port, child_expr->get_output_ports());

expr = child_expr;
extend = true;
extend_global = true;
break; // if one child is extend, stop other consumers. continue this branch
}
if (!extend) {
check_next = false; // no child extend, stop extend deeper
}
}
if (extend_global) {
// if expr port is loop port, replace to new child output port.
auto loop_port = std::find_if(out_ports.begin(), out_ports.end(), [&](LoopPort loop_ports) {
return *loop_ports.expr_port.get() == expr_out_port;
});
if (loop_port != out_ports.end()) {
inner_loop_info->replace_with_new_ports(expr_out_port, {(*std::prev(outer_loop_bounds_second))->get_output_port(0)});
}
}
}
// extend parent
std::cout << "start extend parent" << std::endl;
for (auto expr_it = outer_loop_bounds_first; expr_it != outer_loop_bounds_second; expr_it++) {
auto expr = *expr_it;
std::cout << "expr_for_parent:" << expr->get_node()->get_friendly_name() << std::endl;
auto in_num = expr->get_input_count();
if (in_num > 2 || in_num < 1)
continue;
auto expr_in_port = expr->get_input_port(in_num-1);
bool check_next = true;
const auto& expr_loop = expr->get_loop_ids();
bool extend = false;
while (check_next) {
// if (expr->get_input_count() != 1)
// break;
in_num = expr->get_input_count();
if (in_num > 2 || in_num < 1)
break;
auto parent_expr = expr->get_input_port_connector(in_num-1)->get_source().get_expr();
std::cout << "parent_expr:" << parent_expr->get_node()->get_friendly_name() << std::endl;
if (parent_expr->get_output_count() != 1)
break;
bool is_inside = false;
for (auto expr_check = outer_loop_bounds_first; expr_check != outer_loop_bounds_second; expr_check++) {
if (*expr_check == parent_expr) {
is_inside = true;
break;
auto in_num = expr->get_input_count();
bool extend = false;
for (size_t i = 0; i < in_num; i++) {
auto parent_expr = expr->get_input_port_connector(i)->get_source().get_expr();
bool is_inside = false;
for (auto expr_check = outer_loop_bounds_first; expr_check != outer_loop_bounds_second; expr_check++) {
if (*expr_check == parent_expr) {
is_inside = true;
break;
}
}
}
std::cout << "is_inside:" << is_inside << std::endl;
if (is_inside)
break;
const auto& parent_expr_loop = parent_expr->get_loop_ids();
// have common outer loop id, child w/o inner most loop
for (auto i = 0; i < std::min(expr_loop.size(), parent_expr_loop.size()); i++) {
if (expr_loop[i] != parent_expr_loop[i]) {
break;
std::cout << "is_inside:" << is_inside << std::endl;
if (is_inside)
continue;

// parent_last_loop_dim must not be 0 (the innermost dimension)
const auto& parent_expr_loop = parent_expr->get_loop_ids();
if (parent_expr_loop.size() < 1)
continue;
const auto& parent_last_loop_dim = loop_manager->get_loop_info<UnifiedLoopInfo>(parent_expr_loop.back())->get_dim_idx();
if ((expr_loop.size() - parent_expr_loop.size() != 1) || parent_last_loop_dim == 0) {
continue;
}
}
const auto& last_loop_dim = loop_manager->get_loop_info<UnifiedLoopInfo>(parent_expr_loop.back())->get_dim_idx();
if ((expr_loop.size() - parent_expr_loop.size() == 1) && last_loop_dim != 0) {
outer_loop_bounds_first = linear_ir.find(parent_expr);

bool common_outer_loop = true;
// expr and parent must share the same outer loop ids (the parent has no innermost loop)
for (auto i = 0; i < std::min(expr_loop.size(), parent_expr_loop.size()); i++) {
if (expr_loop[i] != parent_expr_loop[i]) {
common_outer_loop = false;
break;
}
}
if (!common_outer_loop) {
continue;
}

// not break data dependency

// can extend
auto parent_expr_it = linear_ir.find(parent_expr);
if (parent_expr_it != std::prev(outer_loop_bounds_first)) {
linear_ir.move(parent_expr_it, outer_loop_bounds_first); // move new expr in this loop before outer_loop_bounds_first
}
outer_loop_bounds_first--;
std::cout << "extend child:" << parent_expr->get_node()->get_friendly_name() << std::endl;

// Update ports. When extending, in_port must be a loop port because the parent expr is not inside this loop.
auto in_port = expr->get_input_port(i);
inner_loop_info->replace_with_new_ports(in_port, parent_expr->get_input_ports());

expr = parent_expr;
extend = true;
std::cout << "extend........" << std::endl;
} else {
break;
break; // one parent extended, check this branch from parent.
}
expr = parent_expr;
}
if (extend) {
// if expr port is loop port, replace to new child output port.
auto loop_port = std::find_if(in_ports.begin(), in_ports.end(), [&](LoopPort loop_ports) {
return *loop_ports.expr_port.get() == expr_in_port;
});
if (loop_port != out_ports.end()) {
// inner_loop_info->replace_with_new_ports(expr_in_port, {(*outer_loop_bounds_first)->get_output_port(0)});
inner_loop_info->replace_with_new_ports(expr_in_port, {});
if (!extend) {
check_next = false; // no parent extend, stop extend upper
}
}
}
Expand Down
Loading

0 comments on commit b43dfe9

Please sign in to comment.