Skip to content

Commit

Permalink
extend inner most dimension outer block loop
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Oct 24, 2024
1 parent ac360a8 commit 03c2a91
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 104 deletions.
56 changes: 28 additions & 28 deletions src/common/snippets/src/lowered/pass/fuse_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ bool FuseLoops::fuse_lower_into_current(LinearIR& linear_ir, const LoopManagerPt
const auto& loop_current = loop_manager->get_loop_info<UnifiedLoopInfo>(current_loop_id);
const auto& loop_target = loop_manager->get_loop_info<UnifiedLoopInfo>(target_loop_id);
if (!can_be_fused(loop_current, loop_target)) {
std::cout << "fuse_lower_into_current can_be_fused is NOT" << std::endl;
// std::cout << "fuse_lower_into_current can_be_fused is NOT" << std::endl;
return false;
}
// return false;
Expand All @@ -169,17 +169,17 @@ bool FuseLoops::fuse_lower_into_current(LinearIR& linear_ir, const LoopManagerPt
const auto& parent_expr = parent_expr_output.get_expr();
if (ov::is_type<ov::op::v0::Parameter>(parent_expr->get_node()) || parent_expr == current_output_port->get_expr())
continue;
std::cout << "parent_expr:" << parent_expr->get_node()->get_friendly_name() << std::endl;
for (size_t i = 0; i < parent_expr->get_loop_ids().size(); i++) {
std::cout << "parent_expr->get_loop_ids:" << parent_expr->get_loop_ids()[i] << std::endl;
}
std::cout << "current_loop_id:" << current_loop_id << std::endl;
// std::cout << "parent_expr:" << parent_expr->get_node()->get_friendly_name() << std::endl;
// for (size_t i = 0; i < parent_expr->get_loop_ids().size(); i++) {
// std::cout << "parent_expr->get_loop_ids:" << parent_expr->get_loop_ids()[i] << std::endl;
// }
// std::cout << "current_loop_id:" << current_loop_id << std::endl;
auto a = is_loop_id_found(parent_expr->get_loop_ids(), current_loop_id);
auto b = parent_expr->get_exec_num();
auto c = (*current_loop_begin_pos)->get_exec_num();
std::cout << "is_loop_id_found:" << a << std::endl;
std::cout << "parent_expr->get_exec_num():" << b << std::endl;
std::cout << "(*current_loop_begin_pos)->get_exec_num():" << c << std::endl;
// std::cout << "is_loop_id_found:" << a << std::endl;
// std::cout << "parent_expr->get_exec_num():" << b << std::endl;
// std::cout << "(*current_loop_begin_pos)->get_exec_num():" << c << std::endl;
is_fusion_allowed = is_loop_id_found(parent_expr->get_loop_ids(), current_loop_id) || // The parent expr is from the same current Loop
parent_expr->get_exec_num() < (*current_loop_begin_pos)->get_exec_num(); // The parent is before current Loop
// is_fusion_allowed = is_loop_id_found(parent_expr->get_loop_ids(), current_loop_id) || // The parent expr is from the same current Loop
Expand Down Expand Up @@ -227,9 +227,9 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
const auto current_loop_depth = current_expr_loops.size();
for (size_t i = 0; i < current_loop_depth; ++i) {
const auto current_loop_id = current_expr_loops[i];
if (node_name == "PowerStatic_4096") {
std::cout << "current_loop_id:" << current_loop_id << std::endl;
}
// if (node_name == "PowerStatic_4096") {
// std::cout << "current_loop_id:" << current_loop_id << std::endl;
// }
// If the current Loop ID is in prev fused Loops, it means that on previous step all possible fusions are completed
if (prev_fused_loops.count(current_loop_id) != 0)
continue;
Expand Down Expand Up @@ -286,11 +286,11 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
// "Loops cannot have parents of input ports with the same identifier (", upper_loop_id, ")");
if (fuse_upper_into_current(linear_ir, loop_manager, input_port.expr_port, current_loop_id, upper_loop_id,
current_loop_begin_pos, current_loop_end_pos)) {
std::cout << "fuse_upper_into_current ok" << std::endl;
std::cout << "current_loop_id:" << current_loop_id << std::endl;
std::cout << "upper_loop_id:" << upper_loop_id << std::endl;
std::cout << "node:" << node->get_friendly_name() << std::endl;
std::cout << "parent:" << parent->get_friendly_name() << std::endl;
// std::cout << "fuse_upper_into_current ok" << std::endl;
// std::cout << "current_loop_id:" << current_loop_id << std::endl;
// std::cout << "upper_loop_id:" << upper_loop_id << std::endl;
// std::cout << "node:" << node->get_friendly_name() << std::endl;
// std::cout << "parent:" << parent->get_friendly_name() << std::endl;
was_fusion_up = true;
prev_fused_loops.insert(current_loop_id);
current_loop_info = loop_manager->get_loop_info(current_loop_id);
Expand All @@ -304,9 +304,9 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
// Loop_0 (Current) Loop_0 + Loop_1 => new `Loop_0`
// | => |
// Loop_1 (Lower) |
if (node_name == "PowerStatic_4096") {
std::cout << "start fusion_down" << std::endl;
}
// if (node_name == "PowerStatic_4096") {
// std::cout << "start fusion_down" << std::endl;
// }
bool was_fusion_down = false;
const auto& output_ports = current_loop_info->get_output_ports();
for (size_t out_port = 0; !was_fusion_down && out_port < output_ports.size(); ++out_port) {
Expand All @@ -315,9 +315,9 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
for (const auto& consumer_expr_input : consumer_exprs_inputs) {
auto consumer_expr = consumer_expr_input.get_expr();
const auto consumer = consumer_expr->get_node();
if (node_name == "PowerStatic_4096") {
std::cout << "PowerStatic_4096 child:" << consumer->get_friendly_name() << std::endl;
}
// if (node_name == "PowerStatic_4096") {
// std::cout << "PowerStatic_4096 child:" << consumer->get_friendly_name() << std::endl;
// }
// if parent of consumer is brgemm, not fuse
// if (ov::is_type<snippets::op::Brgemm>(consumer)) {
// break;
Expand Down Expand Up @@ -361,11 +361,11 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l

if (fuse_lower_into_current(linear_ir, loop_manager, output_port.expr_port, current_loop_id, lower_loop_id,
current_loop_begin_pos, current_loop_end_pos)) {
std::cout << "fuse_lower_into_current ok" << std::endl;
std::cout << "current_loop_id:" << current_loop_id << std::endl;
std::cout << "lower_loop_id:" << lower_loop_id << std::endl;
std::cout << "node:" << node->get_friendly_name() << std::endl;
std::cout << "consumer:" << consumer->get_friendly_name() << std::endl;
// std::cout << "fuse_lower_into_current ok" << std::endl;
// std::cout << "current_loop_id:" << current_loop_id << std::endl;
// std::cout << "lower_loop_id:" << lower_loop_id << std::endl;
// std::cout << "node:" << node->get_friendly_name() << std::endl;
// std::cout << "consumer:" << consumer->get_friendly_name() << std::endl;
was_fusion_down = true;
prev_fused_loops.insert(current_loop_id);
current_loop_info = loop_manager->get_loop_info(current_loop_id);
Expand Down
Loading

0 comments on commit 03c2a91

Please sign in to comment.