Skip to content

Commit

Permalink
Alexandra comments apply 2
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jun 6, 2024
1 parent 6edf613 commit 54fc2fb
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 41 deletions.
21 changes: 15 additions & 6 deletions src/common/snippets/include/snippets/lowered/loop_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,16 @@ class LoopInfo {
const char* get_type_name() const {
return get_type_info().name;
}

/**
* @brief Find LoopPort in input and output ports
* @param loop_port target port
* @return iterator of the corresponding collection
* @brief Return true if expression port is a loop port
* @param expr_port - expression port to check
*/
template<typename T>
std::vector<LoopPort>::iterator find_loop_port(const T& loop_port);
bool is_loop_port(const ExpressionPort& expr_port);
/**
* @brief Return loop port of an expression port
* @param expr_port - expression port.
*/
const LoopPort& get_loop_port(const ExpressionPort& expr_port);

protected:
/**
Expand All @@ -162,6 +164,13 @@ class LoopInfo {
* @return vector with new cloned loop ports
*/
static std::vector<LoopPort> clone_loop_ports(const ExpressionMap& expr_map, const std::vector<LoopPort>& loop_ports);
/**
* @brief Find LoopPort in input and output ports
* @param loop_port target port
* @return iterator of the corresponding collection
*/
template<typename T>
std::vector<LoopPort>::iterator find_loop_port(const T& loop_port);

size_t m_work_amount = 0;
size_t m_increment = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ namespace pass {
class ExtractLoopInvariants : public RangedPass {
public:
OPENVINO_RTTI("ExtractLoopInvariants", "RangedPass")
ExtractLoopInvariants();
ExtractLoopInvariants() = default;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

private:
bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info);
void extract_expr(const ExpressionPtr& expr, LinearIR& linear_ir,
static std::set<ExpressionPtr> get_potential_extractable_exprs(const std::vector<LoopPort>& loop_in_ports);
static bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info);
static void extract_expr(const ExpressionPtr& expr, LinearIR& linear_ir,
LinearIR::constExprIt& inner_loop_begin_pos, LinearIR::constExprIt& inner_loop_end_pos);
void update_loop_ports(const ExpressionPtr& expr, const LoopManagerPtr& loop_manager, size_t inner_loop_id,
static void update_loop_ports(const ExpressionPtr& expr, const LoopManagerPtr& loop_manager, size_t inner_loop_id,
LinearIR::constExprIt& inner_loop_begin_pos, LinearIR::constExprIt& inner_loop_end_pos);
bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir);
static bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir);
};

} // namespace pass
Expand Down
19 changes: 13 additions & 6 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ std::vector<LoopPort>::iterator LoopInfo::find_loop_port(const LoopPort& loop_po
auto& ports = loop_port.expr_port->get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
const auto it = std::find_if(ports.begin(), ports.end(),
[&loop_port](const LoopPort& port) { return port == loop_port; });
OPENVINO_ASSERT(it != ports.end(), "Failed update_loop_port: existing loop port has not been found");
OPENVINO_ASSERT(it != ports.end(), "Failed find_loop_port: existing loop port has not been found");
return it;
}

Expand All @@ -101,6 +101,17 @@ std::vector<LoopPort>::iterator LoopInfo::find_loop_port(const ExpressionPort& e
return it;
}

/**
 * @brief Check whether an expression port is registered as a port of this loop
 * @param expr_port - expression port to check. Its type (Input or not) selects which
 *                    collection is inspected: m_input_ports or m_output_ports.
 * @return true if find_loop_port() returned an iterator other than end() of that collection
 * NOTE(review): this relies on the ExpressionPort overload of find_loop_port() returning end()
 *               (rather than asserting) when nothing matches - confirm against its definition.
 */
bool LoopInfo::is_loop_port(const ExpressionPort& expr_port) {
    const auto found_it = find_loop_port(expr_port);
    const bool is_input = expr_port.get_type() == ExpressionPort::Input;
    const auto& candidates = is_input ? m_input_ports : m_output_ports;
    return found_it != candidates.end();
}

/**
 * @brief Return the loop port that corresponds to an expression port
 * @param expr_port - expression port. Its type (Input or not) selects which collection
 *                    is inspected: m_input_ports or m_output_ports.
 * @return reference to the matching LoopPort stored in that collection
 * Asserts (OPENVINO_ASSERT) if the expression port is not found among the loop ports.
 */
const LoopPort& LoopInfo::get_loop_port(const ExpressionPort& expr_port) {
    // Look the port up once and reuse the iterator for both the check and the result,
    // instead of searching in is_loop_port() and then again in find_loop_port().
    const auto& ports = expr_port.get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
    const auto loop_port_it = find_loop_port(expr_port);
    OPENVINO_ASSERT(loop_port_it != ports.end(), "Failed get_loop_port: expr_port is not a loop port");
    return *loop_port_it;
}

void LoopInfo::replace_with_new_ports(const LoopPort& actual_port, const std::vector<LoopPort>& target_ports) {
auto& ports = actual_port.expr_port->get_type() == ExpressionPort::Input ? m_input_ports : m_output_ports;
auto port_it = find_loop_port(actual_port);
Expand Down Expand Up @@ -335,12 +346,8 @@ void UnifiedLoopInfo::add_loop_ports(const std::vector<ExpressionPort>& ports, b
for (size_t i = 0; i < ports.size(); i++) {
// if already in loop ports, skip
auto loop_port = find_loop_port(ports[i]);
if (loop_port != loop_ports.end()) {
if (loop_port->dim_idx != loop_dim_idx) {
loop_port->dim_idx = loop_dim_idx;
}
if (loop_port != loop_ports.end())
continue;
}
loop_ports.push_back(LoopPort(ports[i], true, loop_dim_idx));
loop_ports_desc.push_back(LoopPortDesc());
}
Expand Down
37 changes: 21 additions & 16 deletions src/common/snippets/src/lowered/pass/extract_loop_invariants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace pass {
namespace {
// Drops the last entry of the expression's loop ids and writes the shortened list back.
// Asserts if the expression has no loop ids at all.
void remove_last_loop_id(const std::shared_ptr<Expression>& expr) {
    // Work on a local copy: the ids are modified here and then stored via set_loop_ids().
    auto remaining_ids = expr->get_loop_ids();
    OPENVINO_ASSERT(!remaining_ids.empty(), "Expr loop_ids should not be empty when remove last loop id.");
    remaining_ids.pop_back();
    expr->set_loop_ids(remaining_ids);
}
Expand All @@ -36,31 +37,27 @@ size_t get_stride_after_move_outer(const LoopPort& loop_port) {
}
} // namespace

ExtractLoopInvariants::ExtractLoopInvariants() : RangedPass() {}

bool ExtractLoopInvariants::is_extraction_applicable(const ExpressionPtr& expr,
const UnifiedLoopInfoPtr& inner_loop_info) {
const auto& expr_input_ports = expr->get_input_ports();
const auto& input_port_size = expr_input_ports.size();
if (input_port_size == 0)
return false;

const auto& inner_loop_input_ports = inner_loop_info->get_input_ports();

for (size_t i = 0; i < input_port_size; ++i) { // iter expr ports
const auto& parent = expr->get_input_port_connector(i)->get_source().get_expr();
bool parent_scalar_with_single_consumer = ov::is_type<snippets::op::Scalar>(parent->get_node()) &&
parent->get_output_port_connector(0)->get_consumers().size() == 1;

const auto& loop_port = inner_loop_info->find_loop_port(expr_input_ports[i]);
const auto& is_loop_port = inner_loop_info->is_loop_port(expr_input_ports[i]);
// expr input port is not a loop input port, then should not extract
// expr with single scalar parent could be extracted as well, which is common.
if (loop_port == inner_loop_input_ports.end() && !parent_scalar_with_single_consumer) {
if (!is_loop_port && !parent_scalar_with_single_consumer) {
return false;
}
if (loop_port != inner_loop_input_ports.end()) {
if (is_loop_port) {
// after move to outside, stride in inner loop is not 1, then should not extract.
if (get_stride_after_move_outer(*loop_port) != 1) {
const auto& loop_port = inner_loop_info->get_loop_port(expr_input_ports[i]);
if (get_stride_after_move_outer(loop_port) != 1) {
return false;
}
}
Expand Down Expand Up @@ -103,34 +100,42 @@ void ExtractLoopInvariants::update_loop_ports(const ExpressionPtr& expr, const L
inner_loop_info->update_loop_ports(expr_input_ports, new_loop_input_ports, true);

// delete expr out ports from loop out ports directly if it's in loop output ports
const auto& loop_out_ports = inner_loop_info->get_output_ports();
std::vector<ExpressionPort> exp_out_ports;
for (size_t i = 0; i < expr->get_output_count(); ++i) {
const auto& out_port = expr->get_output_port(i);
if (inner_loop_info->find_loop_port(out_port) != loop_out_ports.end()) {
if (inner_loop_info->is_loop_port(out_port)) {
exp_out_ports.push_back(out_port);
}
}
if (!exp_out_ports.empty()) {
std::vector<ExpressionPort> new_ports;
inner_loop_info->update_loop_ports(exp_out_ports, new_ports, false);
}
// need sort after update loop ports. There are possibility that all exprs are moved to outer loop.
// TODO: 142990.
// Loop ports need to be sorted after the update. It is possible that all exprs have been moved to the outer loop.
if (!inner_loop_info->get_input_ports().empty() && !inner_loop_info->get_output_ports().empty()) {
loop_manager->sort_loop_ports(inner_loop_begin_pos, inner_loop_end_pos, inner_loop_id);
}
}

std::set<ExpressionPtr> ExtractLoopInvariants::get_potential_extractable_exprs(const std::vector<LoopPort>& loop_in_ports) {
std::set<ExpressionPtr> expr_set;
for (size_t i = 0; i < loop_in_ports.size(); ++i) {
expr_set.insert(loop_in_ports[i].expr_port->get_expr());
}
return expr_set;
}

bool ExtractLoopInvariants::extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir) {
const auto& loop_manager = linear_ir.get_loop_manager();
bool status = false;
bool extraction_completed = false; // true means all extractable exprs are extracted from this loop
const auto& inner_loop_info = loop_manager->get_loop_info<UnifiedLoopInfo>(inner_loop_id);
while (!extraction_completed) {
const auto& inner_loop_input_ports = inner_loop_info->get_input_ports();
const auto& potential_extractable_exprs = get_potential_extractable_exprs(inner_loop_input_ports);
bool has_been_extraction = false;
for (size_t i = 0; i < inner_loop_input_ports.size(); i++) { // iter loop ports
const auto& port_expr = inner_loop_input_ports[i].expr_port->get_expr();
for (const auto& port_expr : potential_extractable_exprs) {
if (is_extraction_applicable(port_expr, inner_loop_info)) {
status = true;
LinearIR::constExprIt inner_loop_begin_pos, inner_loop_end_pos;
Expand All @@ -146,11 +151,11 @@ bool ExtractLoopInvariants::extract_from_loop(const size_t& inner_loop_id, Linea
extract_expr(port_expr, linear_ir, inner_loop_begin_pos, inner_loop_end_pos);
update_loop_ports(port_expr, loop_manager, inner_loop_id, inner_loop_begin_pos, inner_loop_end_pos);
has_been_extraction = true;
break; // extracted and refreshed loop_input_ports. break for() of iter loop ports and go while() to start again.
break; // extracted and refreshed loop_input_ports. break potential_extractable_exprs loop, and go while() to start again.
}
}
if (inner_loop_input_ports.size() == 0 && inner_loop_info->get_output_ports().size() == 0) {
// become a empty loop, let remove it from loop_manager
// the loop has become empty after extraction (inner_loop_input_ports is a ref), so remove it from loop_manager
loop_manager->remove_loop_info(inner_loop_id);
extraction_completed = true;
break;
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/tests/src/lir_comparator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ LIRComparator::Result LIRComparator::compare_loop_managers_without_index(const L
for (auto map_iter = map.cbegin(), map_ref_iter = map_ref.cbegin(); map_iter != map.cend(); map_iter++, map_ref_iter++) {
const auto& loop_info = map_iter->second;
const auto& loop_info_ref = map_ref_iter->second;
PROPAGATE_ERROR("Loop with id " + std::to_string(map_iter->first) + "and Ref Loop with id" + std::to_string(map_ref_iter->first),
PROPAGATE_ERROR("Loop with id " + std::to_string(map_iter->first) + " and ref Loop with id " + std::to_string(map_ref_iter->first),
compare_loop_info(loop_info, loop_info_ref));
}
return Result::ok();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using namespace ov::snippets::lowered::pass;
class ExtractLoopInvariantsTest : public LoweredPassTestsF {
public:
ExtractLoopInvariantsTest() : LoweredPassTestsF() {
comparator.disable(LIRComparator::LIRCmpValues::LOOP_INDICES); // loop could be removed and index could be different
comparator.disable(LIRComparator::LIRCmpValues::LOOP_INDICES); // loop could be removed and loop index could be different
comparator.enable(LIRComparator::LIRCmpValues::PORT_DESCRIPTORS);
comparator.enable(LIRComparator::LIRCmpValues::PORT_CONNECTORS);
comparator.enable(LIRComparator::LIRCmpValues::LOOP_MANAGER);
Expand All @@ -28,7 +28,7 @@ class ExtractLoopInvariantsTest : public LoweredPassTestsF {
}
};

TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsWithParams) {
TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsWithParams) {
size_t vector_size = 16;
const auto input_precision = ov::element::f32;
const ov::Shape input_shape0{1};
Expand Down Expand Up @@ -84,7 +84,7 @@ TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsWithParams) {
}
}

TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsWithScalar) {
TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsWithScalar) {
size_t vector_size = 16;
const auto input_precision = ov::element::f32;
const ov::Shape scalar_shape{1};
Expand Down Expand Up @@ -140,7 +140,7 @@ TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsWithScalar) {
}
}

TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsOutputLoopUpdateNotNeed) {
TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsOutputLoopUpdateNotNeed) {
size_t vector_size = 16;
const auto input_precision = ov::element::f32;
const ov::Shape input_shape_a{3, 1};
Expand All @@ -149,8 +149,8 @@ TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsOutputLoopUpdateNotNeed)
const std::vector<ov::snippets::VectorDims> subtensor_mul{{3, 1}, {3, 1}, {3, 1}};
const std::vector<ov::snippets::VectorDims> subtensor_add{{3, 16}, {3, 16}, {3, 16}};
/*
* Befor: Param0, Param1, Param2, [[Multiply, Broadcast, Add, Sub]], Result0, Result1
* After: Param0, Param1, Param2, [Multiply, Broadcast, [Add, Sub]], Result0, Result1
* Before: Param0, Param1, Param2, [[Multiply, Broadcast, Add, Sub]], Result0, Result1
* After: Param0, Param1, Param2, [Multiply, Broadcast, [Add, Sub]], Result0, Result1
* Param0(3,1) Param1(3,1)
* \ /
* Multiply
Expand Down Expand Up @@ -228,7 +228,7 @@ TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsOutputLoopUpdateNotNeed)

// softmax with shape of 1 for innermost dimension.
// Covers the case where multiple (all) exprs are extracted and the loop is removed.
TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsAllExprsInLoopExtracted) {
TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsAllExprsInLoopExtracted) {
size_t vector_size = 16;
const auto input_precision = ov::element::f32;
const ov::Shape input_shape{10, 1};
Expand Down Expand Up @@ -324,6 +324,61 @@ TEST_F(ExtractLoopInvariantsTest, ExtractLoopInvariantsAllExprsInLoopExtracted)
}
}

// Scenario: BroadcastMove fed by a (1,1) parameter initially sits inside both nested loops.
// The reference LIR below keeps it outside both loops, i.e. it is expected to be moved out of
// the innermost loop and then out of the outer loop (see before/intermediate/after).
TEST_F(ExtractLoopInvariantsTest, ExtractedLoopInvariantsFromInnermostToLoopOutside) {
size_t vector_size = 16;
const auto input_precision = ov::element::f32;
const ov::Shape input_shape_0{3, 512};
const ov::Shape input_shape_1{1, 1};
ov::snippets::VectorDims layout{0, 1};
ov::snippets::VectorDims subtensor{3, 512};
/*
 * before:       Param0, Param1, [[Broadcast, Add]], Result
 * intermediate: Param0, Param1, [Broadcast, [Add]], Result
 * after:        Param0, Param1, Broadcast, [[Add]], Result
 *
 *   Param0(3,512)    Param1(1,1)
 *         \              /
 *          \         Broadcast
 *           \          /
 *               Add
 *                |
 *              Result
 */
// Original LIR: both loops are created with BroadcastMove as their first expression.
{
auto param_0 = linear_ir->push_node<ov::opset10::Parameter>(input_precision, input_shape_0);
auto param_1 = linear_ir->push_node<ov::opset10::Parameter>(input_precision, input_shape_1);
auto broadcastmove = linear_ir->push_node<ov::snippets::op::BroadcastMove>(param_1.second, 512);
init_expr_descriptors(*broadcastmove.first, {{1, 1}, subtensor}, {layout, layout});
auto add = linear_ir->push_node<ov::opset10::Add>(param_0.second, broadcastmove.second);
init_expr_descriptors(*add.first, {subtensor, subtensor, subtensor}, {layout, layout, layout});
auto result = linear_ir->push_node<ov::opset10::Result>(add.second);
// NOTE(review): the helper's arguments appear to be (linear_ir, loop begin, loop end, work amount,
// increment, entry ports, exit ports) - confirm against its declaration.
// Loop over the dimension of size 3: entered through BroadcastMove input 0 and Add input 0.
create_and_add_unified_loop_info(linear_ir, broadcastmove.first, result.first, 3, 1,
{LoopPort((*broadcastmove.first)->get_input_port(0), true, 1),
LoopPort((*add.first)->get_input_port(0), true, 1)},
{LoopPort((*add.first)->get_output_port(0), true, 1)});
// Loop over the dimension of size 512 with step vector_size: same ports, last LoopPort argument 0.
create_and_add_unified_loop_info(linear_ir, broadcastmove.first, result.first, 512, vector_size,
{LoopPort((*broadcastmove.first)->get_input_port(0), true, 0),
LoopPort((*add.first)->get_input_port(0), true, 0)},
{LoopPort((*add.first)->get_output_port(0), true, 0)});
}
// Reference LIR: same expressions, but both loops start at Add, so BroadcastMove stays outside them.
{
auto param_0 = linear_ir_ref->push_node<ov::opset10::Parameter>(input_precision, input_shape_0);
auto param_1 = linear_ir_ref->push_node<ov::opset10::Parameter>(input_precision, input_shape_1);
auto broadcastmove = linear_ir_ref->push_node<ov::snippets::op::BroadcastMove>(param_1.second, 512);
init_expr_descriptors(*broadcastmove.first, {{1, 1}, subtensor}, {layout, layout});
auto add = linear_ir_ref->push_node<ov::opset10::Add>(param_0.second, broadcastmove.second);
init_expr_descriptors(*add.first, {subtensor, subtensor, subtensor}, {layout, layout, layout});
auto result = linear_ir_ref->push_node<ov::opset10::Result>(add.second);
// With BroadcastMove outside the loops, both Add inputs (0 and 1) are the loop entry ports.
create_and_add_unified_loop_info(linear_ir_ref, add.first, result.first, 3, 1,
{LoopPort((*add.first)->get_input_port(0), true, 1),
LoopPort((*add.first)->get_input_port(1), true, 1)},
{LoopPort((*add.first)->get_output_port(0), true, 1)});
create_and_add_unified_loop_info(linear_ir_ref, add.first, result.first, 512, vector_size,
{LoopPort((*add.first)->get_input_port(0), true, 0),
LoopPort((*add.first)->get_input_port(1), true, 0)},
{LoopPort((*add.first)->get_output_port(0), true, 0)});
}
}

} // namespace snippets
} // namespace test
} // namespace ov

0 comments on commit 54fc2fb

Please sign in to comment.