Skip to content

Commit aa221ae

Browse files
committed
Apply Ivan review comments
1 parent 5cd9250 commit aa221ae

File tree

14 files changed

+121
-164
lines changed

14 files changed

+121
-164
lines changed

src/common/snippets/include/snippets/lowered/linear_ir.hpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -223,34 +223,6 @@ class LinearIR {
223223
*/
224224
exprIt replace_with_expr(const std::vector<ExpressionPtr>& old_exprs, const ExpressionPtr& new_expr);
225225

226-
/**
227-
* @brief Get zero to several consecutive child shape infer exprs(such as reshape, rankNormalization) from start_expr.
228-
* @param start_expr Collect from start_expr.
229-
* @return shape infer expression consumers as a sequence.
230-
*/
231-
static std::vector<ExpressionPtr> get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr);
232-
233-
/**
234-
* @brief Get zero to several consecutive parent shape infer exprs(such as reshape, rankNormalization) from start_expr.
235-
* @param start_expr Collect from start_expr.
236-
* @return shape infer expression sources as a sequence.
237-
*/
238-
static std::vector<ExpressionPtr> get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr);
239-
240-
/**
241-
* @brief Get last child shape infer op from start_expr in a sequence. If no shape infer op is connect to start_expr, return start_expr.
242-
* @param start_expr Search from start_expr.
243-
* @return last child shape infer expr
244-
*/
245-
static ExpressionPtr get_last_child_shape_infer_expr(const ExpressionPtr& start_expr);
246-
247-
/**
248-
* @brief Get last parent shape infer op from start_expr in a sequence. If no shape infer op is connect to start_expr, return start_expr.
249-
* @param start_expr Search from start_expr.
250-
* @return last parent shape infer expr
251-
*/
252-
static ExpressionPtr get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr);
253-
254226
private:
255227
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;
256228

src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ class ReduceShapeInfer : public IShapeInferSnippets {
7676
};
7777

7878
class ReshapeShapeInfer : public IShapeInferSnippets {
79-
ov::PartialShape target_shape;
79+
VectorDims target_shape;
80+
size_t target_shape_volume = 0;
8081
public:
8182
explicit ReshapeShapeInfer(const std::shared_ptr<Node>& n);
8283
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;

src/common/snippets/include/snippets/utils.hpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
154154
* @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order.
155155
*/
156156
VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port);
157+
/* --------------------------- */
158+
157159
/**
158160
* @brief Returns element count of a shape
159161
* @param shape input shape
@@ -162,7 +164,22 @@ VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_po
162164
inline auto get_shape_size(const VectorDims& shape) -> size_t {
163165
return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
164166
}
165-
/* --------------------------- */
167+
168+
/**
169+
* @brief Get zero to several consecutive child shape infer exprs(such as reshape, rankNormalization) from start_expr.
170+
* As Node maybe have multiple output. This functions return the first(left) legal sequence.
171+
* @param start_expr Collect from start_expr.
172+
* @return shape infer expression consumers as a sequence.
173+
*/
174+
std::vector<lowered::ExpressionPtr> get_first_child_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr);
175+
176+
/**
177+
* @brief Get zero to several consecutive parent shape infer exprs(such as reshape, rankNormalization) from start_expr.
178+
* As Node maybe have multiple input. This functions return the first(left) legal sequence.
179+
* @param start_expr Collect from start_expr.
180+
* @return shape infer expression sources as a sequence.
181+
*/
182+
std::vector<lowered::ExpressionPtr> get_first_parent_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr);
166183

167184
} // namespace utils
168185
} // namespace snippets

src/common/snippets/src/lowered/linear_ir.cpp

Lines changed: 3 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,9 @@ VectorDims LinearIR::get_master_shape() const {
371371
if (!m_config.m_enable_domain_optimization && ov::is_type<snippets::op::Brgemm>(source.get_expr()->get_node())) {
372372
master_shape = utils::get_preordered_vdims(source);
373373
} else {
374-
auto last_shape_infer_expr = LinearIR::get_last_parent_shape_infer_expr(out_exprs[0]);
375-
master_shape = utils::get_preordered_vdims(last_shape_infer_expr->get_input_port_connector(0)->get_source());
374+
const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(out_exprs[0]);
375+
const auto& expr = shape_infer_seq.empty() ? out_exprs[0] : shape_infer_seq.back();
376+
master_shape = utils::get_preordered_vdims(expr->get_input_port_connector(0)->get_source());
376377
}
377378
} else {
378379
for (const auto& oe : out_exprs) {
@@ -498,92 +499,6 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
498499
return replace_with_expr(old_exprs, new_expr, insertion_place);
499500
}
500501

501-
std::vector<ExpressionPtr> LinearIR::get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
502-
std::vector<ExpressionPtr> shape_infer_exprs;
503-
auto current_exp = start_expr;
504-
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
505-
OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
506-
shape_infer_exprs.push_back(current_exp);
507-
}
508-
if (current_exp->get_output_count() == 0)
509-
return shape_infer_exprs;
510-
auto output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
511-
auto first_child = output_consumers.begin()->get_expr();
512-
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
513-
OPENVINO_ASSERT(output_consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
514-
shape_infer_exprs.push_back(first_child);
515-
current_exp = first_child;
516-
if (current_exp->get_output_count() == 0)
517-
break;
518-
output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
519-
first_child = output_consumers.begin()->get_expr();
520-
}
521-
return shape_infer_exprs;
522-
}
523-
524-
std::vector<ExpressionPtr> LinearIR::get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
525-
std::vector<ExpressionPtr> shape_infer_exprs;
526-
auto current_exp = start_expr;
527-
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
528-
OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
529-
shape_infer_exprs.push_back(current_exp);
530-
}
531-
if (current_exp->get_input_count() == 0)
532-
return shape_infer_exprs;
533-
auto input = current_exp->get_input_port_connector(0);
534-
auto first_parent = input->get_source().get_expr();
535-
while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
536-
shape_infer_exprs.push_back(first_parent);
537-
current_exp = first_parent;
538-
if (current_exp->get_input_count() == 0)
539-
break;
540-
input = current_exp->get_input_port_connector(0);
541-
first_parent = input->get_source().get_expr();
542-
if (!ov::is_type<snippets::op::Store>(first_parent->get_node())) {
543-
// there are maybe some loopEnd consumers of store as well for loop code gen purpose
544-
OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops.");
545-
}
546-
}
547-
return shape_infer_exprs;
548-
}
549-
550-
ExpressionPtr LinearIR::get_last_child_shape_infer_expr(const ExpressionPtr& start_expr) {
551-
auto last_exp = start_expr;
552-
if (last_exp->get_output_count() == 0)
553-
return last_exp;
554-
auto consumers = last_exp->get_output_port_connector(0)->get_consumers();
555-
auto first_child = consumers.begin()->get_expr();
556-
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
557-
OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
558-
last_exp = first_child;
559-
if (last_exp->get_output_count() == 0)
560-
break;
561-
consumers = last_exp->get_output_port_connector(0)->get_consumers();
562-
first_child = consumers.begin()->get_expr();
563-
}
564-
return last_exp;
565-
}
566-
567-
ExpressionPtr LinearIR::get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr) {
568-
auto last_exp = start_expr;
569-
if (last_exp->get_input_count() == 0)
570-
return last_exp;
571-
auto input = last_exp->get_input_port_connector(0);
572-
auto first_parent = input->get_source().get_expr();
573-
while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
574-
last_exp = first_parent;
575-
if (last_exp->get_input_count() == 0)
576-
break;
577-
input = last_exp->get_input_port_connector(0);
578-
first_parent = input->get_source().get_expr();
579-
if (!ov::is_type<snippets::op::Store>(first_parent->get_node())) {
580-
// there are maybe some loopEnd consumers of store as well for loop code gen purpose
581-
OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops.");
582-
}
583-
}
584-
return last_exp;
585-
}
586-
587502
LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs)
588503
: ShapeInferSnippetsNode(),
589504
m_exprs{std::make_shared<container>(body_exprs)} {

src/common/snippets/src/lowered/pass/allocate_buffers.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "snippets/lowered/pass/normalize_buffer_ids.hpp"
1414
#include "snippets/pass/tokenization.hpp"
1515
#include "snippets/itt.hpp"
16+
#include "snippets/utils.hpp"
1617

1718
namespace ov {
1819
namespace snippets {
@@ -46,8 +47,9 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
4647
}
4748
}
4849
// Propagate to down: in Load. Buffer can have several Load
49-
auto last_shape_infer = ov::snippets::lowered::LinearIR::get_last_child_shape_infer_expr(buffer_expr);
50-
const auto& buffer_out = last_shape_infer->get_output_port_connector(0);
50+
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(buffer_expr);
51+
const auto& target_expr = shape_infer_seq.empty() ? buffer_expr : shape_infer_seq.back();
52+
const auto& buffer_out = target_expr->get_output_port_connector(0);
5153
for (const auto& child_expr_input : buffer_out->get_consumers()) {
5254
const auto& child_expr = child_expr_input.get_expr();
5355
const auto port = child_expr_input.get_index();

src/common/snippets/src/lowered/pass/assign_registers.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
#include "snippets/lowered/pass/assign_registers.hpp"
66

77
#include "snippets/lowered/linear_ir.hpp"
8-
#include "snippets/op/subgraph.hpp"
98
#include "snippets/snippets_isa.hpp"
109
#include "snippets/itt.hpp"
10+
#include "snippets/utils.hpp"
1111

1212
// This header is needed to avoid MSVC warning "C2039: 'inserter': is not a member of 'std'"
1313
#include <iterator>
@@ -82,20 +82,16 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
8282
manually_assigned_gprs[out_connector] = io_expr->get_index();
8383
// TODO [96434]: Support shape infer ops in arbitrary place in pipeline, not just after inputs
8484
// shape infer ops sequence after input
85-
auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(io_expr);
86-
if (!shape_infer_consumers.empty()) {
87-
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
88-
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
89-
}
85+
const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(io_expr);
86+
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
87+
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
9088
}
9189
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
9290
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
9391
// shape infer ops sequence before result
94-
auto shape_infer_sources = LinearIR::get_parent_shape_infer_expr_seq(io_expr);
95-
if (!shape_infer_sources.empty()) {
96-
for (const auto& parent_shape_infer_expr : shape_infer_sources) {
97-
manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
98-
}
92+
const auto& shape_infer_sources = utils::get_first_parent_shape_infer_expr_seq(io_expr);
93+
for (const auto& parent_shape_infer_expr : shape_infer_sources) {
94+
manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
9995
}
10096
} else {
10197
OPENVINO_THROW("Unsupported io_type detected");
@@ -108,13 +104,11 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
108104
static_cast<Reg>(num_results + num_parameters + buffer_id);
109105
// shape infer ops in the middle of subgraph. IntermediateMemoryBuffer is inserted before reshape as new loop should start.
110106
// child shape info ops share the same memory as IntermediateMemoryBuffer.
111-
auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(expr);
112-
if (!shape_infer_consumers.empty()) {
113-
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
114-
manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
115-
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] =
116-
static_cast<Reg>(num_results + num_parameters + buffer_id);
117-
}
107+
const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(expr);
108+
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
109+
manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
110+
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] =
111+
static_cast<Reg>(num_results + num_parameters + buffer_id);
118112
}
119113
}
120114
manually_assigned_gprs[expr->get_output_port_connector(0)] =

src/common/snippets/src/lowered/pass/insert_buffers.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,17 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
148148
const auto port_idx = entry_port->get_index();
149149
const auto node = expr->get_node();
150150
auto parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();
151-
152-
const auto& first_parent_expr = parent_expr_output.get_expr();
151+
auto parent_expr = parent_expr_output.get_expr();
153152
bool has_shape_infer_parent = false;
154153
auto top_shape_infer_expr = expr;
155154
// parent before shape infer ops is used to determine if buffer needed according loopInfo
156-
auto shape_infer_parents = LinearIR::get_parent_shape_infer_expr_seq(first_parent_expr);
155+
const auto& shape_infer_parents = utils::get_first_parent_shape_infer_expr_seq(parent_expr);
157156
if (!shape_infer_parents.empty()) {
158157
parent_expr_output = shape_infer_parents.back()->get_input_port_connector(0)->get_source();
159158
has_shape_infer_parent = true;
160159
top_shape_infer_expr = shape_infer_parents.back();
160+
parent_expr = parent_expr_output.get_expr();
161161
}
162-
163-
const auto& parent_expr = parent_expr_output.get_expr();
164162
const auto& parent_port = parent_expr_output.get_index();
165163
const auto& parent = parent_expr->get_node();
166164
if (ov::is_type<op::Buffer>(parent) ||

src/common/snippets/src/lowered/pass/insert_load_store.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const {
3535
}
3636

3737
bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
38-
std::shared_ptr<Expression> data_expr = *data_expr_it;
39-
data_expr = LinearIR::get_last_child_shape_infer_expr(data_expr);
38+
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(*data_expr_it);
39+
const std::shared_ptr<Expression>& data_expr = shape_infer_seq.empty() ? *data_expr_it : shape_infer_seq.back();
4040
const auto& data_ngraph_output = data_expr->get_node()->output(0);
4141
bool was_inserted = false;
4242
const auto& data_out = data_expr->get_output_port_connector(0);
@@ -56,8 +56,8 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
5656
}
5757

5858
bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
59-
auto data_expr = *data_expr_it;
60-
data_expr = LinearIR::get_last_parent_shape_infer_expr(data_expr);
59+
const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(*data_expr_it);
60+
const auto& data_expr = shape_infer_seq.empty() ? *data_expr_it : shape_infer_seq.back();
6161

6262
const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
6363
const auto& parent_expr = parent_output.get_expr();

src/common/snippets/src/lowered/pass/validate.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ void validate_ports(const ExpressionPtr& expr) {
3232
void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
3333
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Parameter>(expr->get_node()),
3434
"Parameter validation expects Parameter op");
35-
auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
35+
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(expr);
36+
const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
3637
auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers();
3738
std::set<std::vector<size_t>> layouts;
3839
for (const auto& consumer_input : consumer_inputs) {
@@ -51,7 +52,8 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
5152
void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) {
5253
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Result>(expr->get_node()),
5354
"Result validation expects Result op");
54-
auto expr_val = LinearIR::get_last_parent_shape_infer_expr(expr);
55+
const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(expr);
56+
const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
5557
const auto source = expr_val->get_input_port_connector(0)->get_source();
5658
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
5759
OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()),
@@ -66,7 +68,8 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) {
6668
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
6769
OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()),
6870
"Buffer expects MemoryAccess parent");
69-
auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
71+
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(expr);
72+
const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
7073
const auto& out = expr_val->get_output_port_connector(0);
7174
const auto consumers = out->get_consumers();
7275
for (const auto& consumer_input : consumers) {

src/common/snippets/src/op/subgraph.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ auto Subgraph::get_last_child_shape_infer_op(const std::shared_ptr<ov::Node>& op
9393
return last_op;
9494
auto consumers = last_op->get_output_target_inputs(0);
9595
auto first_child = consumers.begin()->get_node()->shared_from_this();
96-
while (op::Subgraph::is_shape_infer_op(first_child)) {
96+
while (is_shape_infer_op(first_child)) {
9797
OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
9898
last_op = first_child;
9999
if (last_op->get_output_size() == 0)
@@ -109,7 +109,7 @@ auto Subgraph::get_last_parent_shape_infer_op(const std::shared_ptr<ov::Node>& o
109109
if (last_op->get_input_size() == 0)
110110
return last_op;
111111
auto first_parent = last_op->get_input_node_shared_ptr(0);
112-
while (op::Subgraph::is_shape_infer_op(first_parent)) {
112+
while (is_shape_infer_op(first_parent)) {
113113
last_op = first_parent;
114114
if (last_op->get_input_size() == 0)
115115
break;

0 commit comments

Comments
 (0)