Skip to content

Commit

Permalink
innermost loop split extend. update brgemm port
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Nov 4, 2024
1 parent b43dfe9 commit 6116fbe
Show file tree
Hide file tree
Showing 7 changed files with 808 additions and 174 deletions.
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
float beta = 0.f);
Brgemm() = default;
Brgemm(bool c_pre_scale);

size_t get_offset_a() const { return get_input_offset(0); }
size_t get_offset_b() const { return get_input_offset(1); }
Expand Down
26 changes: 20 additions & 6 deletions src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,24 +146,38 @@ bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear
brgemm_expr->get_output_port_descriptor(0)->set_subtensor(ov::snippets::VectorDims{m_block, n_block});

const auto& loop_manager = linear_ir.get_loop_manager();
const auto& brgemm_node = std::dynamic_pointer_cast<op::Brgemm>((*brgemm_it)->get_node());
if (!ov::snippets::utils::is_full_dim_value(k_block)) {
// const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true, 0),
// LoopPort(brgemm_expr->get_input_port(1), true, 1)};
// const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true),
LoopPort(brgemm_expr->get_input_port(1), true)};
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true, 0),
LoopPort(brgemm_expr->get_input_port(1), true, 0)};
// if (brgemm_node->with_c_pre_ops) {
// entries.push_back(LoopPort(brgemm_expr->get_input_port(2), false, 0));
// }
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false, 0)};
mark_k_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, k_block);
}
if (!ov::snippets::utils::is_full_dim_value(n_block)) {
const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), false),
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), false),
LoopPort(brgemm_expr->get_input_port(1), true)};
// if (brgemm_node->with_c_pre_ops) {
// entries.push_back(LoopPort(brgemm_expr->get_input_port(2), false));
// }
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};
mark_n_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, n_block);
}
if (!ov::snippets::utils::is_full_dim_value(m_block)) {
const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true),
LoopPort(brgemm_expr->get_input_port(1), false)};
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true),
LoopPort(brgemm_expr->get_input_port(1), false)};
// if (brgemm_node->with_c_pre_ops) {
// std::cout << "brgemm_node->with_c_pre_ops.................." << std::endl;
// entries.push_back(LoopPort(brgemm_expr->get_input_port(2), true));
// }
std::cout << "brgemm_node:" << brgemm_node->get_friendly_name() << std::endl;
std::cout << "brgemm_node->with_c_pre_ops not.................." << std::endl;
std::cout << "brgemm_node input:" << brgemm_node->get_input_count() << std::endl;
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};
mark_m_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, m_block);
}
Expand Down
Loading

0 comments on commit 6116fbe

Please sign in to comment.