Skip to content

Commit

Permalink
Alexandra comments apply 2
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jun 5, 2024
1 parent 6edf613 commit 4620a87
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ namespace pass {
class ExtractLoopInvariants : public RangedPass {
public:
OPENVINO_RTTI("ExtractLoopInvariants", "RangedPass")
ExtractLoopInvariants();
ExtractLoopInvariants() = default;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

private:
bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info);
void extract_expr(const ExpressionPtr& expr, LinearIR& linear_ir,
static bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPtr& inner_loop_info);
static void extract_expr(const ExpressionPtr& expr, LinearIR& linear_ir,
LinearIR::constExprIt& inner_loop_begin_pos, LinearIR::constExprIt& inner_loop_end_pos);
void update_loop_ports(const ExpressionPtr& expr, const LoopManagerPtr& loop_manager, size_t inner_loop_id,
static void update_loop_ports(const ExpressionPtr& expr, const LoopManagerPtr& loop_manager, size_t inner_loop_id,
LinearIR::constExprIt& inner_loop_begin_pos, LinearIR::constExprIt& inner_loop_end_pos);
bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir);
static bool extract_from_loop(const size_t& inner_loop_id, LinearIR& linear_ir);
};

} // namespace pass
Expand Down
6 changes: 1 addition & 5 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,8 @@ void UnifiedLoopInfo::add_loop_ports(const std::vector<ExpressionPort>& ports, b
for (size_t i = 0; i < ports.size(); i++) {
// if already in loop ports, skip
auto loop_port = find_loop_port(ports[i]);
if (loop_port != loop_ports.end()) {
if (loop_port->dim_idx != loop_dim_idx) {
loop_port->dim_idx = loop_dim_idx;
}
if (loop_port != loop_ports.end())
continue;
}
loop_ports.push_back(LoopPort(ports[i], true, loop_dim_idx));
loop_ports_desc.push_back(LoopPortDesc());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace pass {
namespace {
void remove_last_loop_id(const std::shared_ptr<Expression>& expr) {
auto loop_ids = expr->get_loop_ids();
OPENVINO_ASSERT(!loop_ids.empty(), "Expr loop_ids should not be empty when remove last loop id.");
loop_ids.pop_back();
expr->set_loop_ids(loop_ids);
}
Expand All @@ -36,8 +37,6 @@ size_t get_stride_after_move_outer(const LoopPort& loop_port) {
}
} // namespace

ExtractLoopInvariants::ExtractLoopInvariants() : RangedPass() {}

bool ExtractLoopInvariants::is_extraction_applicable(const ExpressionPtr& expr,
const UnifiedLoopInfoPtr& inner_loop_info) {
const auto& expr_input_ports = expr->get_input_ports();
Expand Down Expand Up @@ -115,7 +114,8 @@ void ExtractLoopInvariants::update_loop_ports(const ExpressionPtr& expr, const L
std::vector<ExpressionPort> new_ports;
inner_loop_info->update_loop_ports(exp_out_ports, new_ports, false);
}
// need sort after update loop ports. There are possibility that all exprs are moved to outer loop.
// TODO: 142990.
// Need sort after update loop ports. There are possibility that all exprs are moved to outer loop.
if (!inner_loop_info->get_input_ports().empty() && !inner_loop_info->get_output_ports().empty()) {
loop_manager->sort_loop_ports(inner_loop_begin_pos, inner_loop_end_pos, inner_loop_id);
}
Expand Down

0 comments on commit 4620a87

Please sign in to comment.