Apply Vladislav comments
chenhu-wang committed Mar 25, 2024
1 parent 105387f commit 5cd9250
Showing 13 changed files with 145 additions and 123 deletions.
28 changes: 20 additions & 8 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
@@ -224,20 +224,32 @@ class LinearIR {
exprIt replace_with_expr(const std::vector<ExpressionPtr>& old_exprs, const ExpressionPtr& new_expr);

/**
* @brief Get zero to several consecutive shape infer exprs (such as reshape, rankNormalization) from start_expr.
* @brief Get zero to several consecutive child shape infer exprs (such as reshape, rankNormalization) from start_expr.
* @param start_expr Collect from start_expr.
* @param downstream Collect downstream if it's true, otherwise collect upstream.
* @return shape infer op consumers as a sequence if downstream, or shape infer op sources as a sequence if upstream.
* @return shape infer expression consumers as a sequence.
*/
static std::vector<ExpressionPtr> get_shape_infer_expr_seq(const ExpressionPtr& start_expr, bool downstream);
static std::vector<ExpressionPtr> get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr);

/**
* @brief Get last shape infer op from start_expr in a sequence. If no shape infer op is connected to start_expr, return start_expr.
* @brief Get zero to several consecutive parent shape infer exprs (such as reshape, rankNormalization) from start_expr.
* @param start_expr Collect from start_expr.
* @return shape infer expression sources as a sequence.
*/
static std::vector<ExpressionPtr> get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr);

/**
* @brief Get last child shape infer op from start_expr in a sequence. If no shape infer op is connected to start_expr, return start_expr.
* @param start_expr Search from start_expr.
* @return last child shape infer expr
*/
static ExpressionPtr get_last_child_shape_infer_expr(const ExpressionPtr& start_expr);

/**
* @brief Get last parent shape infer op from start_expr in a sequence. If no shape infer op is connected to start_expr, return start_expr.
* @param start_expr Search from start_expr.
* @param downstream search downstream if it's true, otherwise search upstream.
* @return last shape infer expr
* @return last parent shape infer expr
*/
static ExpressionPtr get_last_shape_infer_expr(const ExpressionPtr& start_expr, bool downstream);
static ExpressionPtr get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr);

private:
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;
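
The net effect of the header change above is that the single bool-direction API is split into explicitly named child (downstream) and parent (upstream) variants. A minimal call-site sketch of the migration follows; the io_expr variable and the commented-out "before" lines are illustrative only, while the LinearIR helpers are the ones declared above:

// Before: traversal direction passed as a bool, easy to get backwards.
// auto consumers = LinearIR::get_shape_infer_expr_seq(io_expr, /*downstream=*/true);
// auto sources   = LinearIR::get_shape_infer_expr_seq(io_expr, /*downstream=*/false);

// After: the direction is part of the name.
auto consumers = LinearIR::get_child_shape_infer_expr_seq(io_expr);      // downstream walk
auto sources   = LinearIR::get_parent_shape_infer_expr_seq(io_expr);     // upstream walk
auto last_child  = LinearIR::get_last_child_shape_infer_expr(io_expr);   // returns io_expr if no child shape infer op
auto last_parent = LinearIR::get_last_parent_shape_infer_expr(io_expr);  // returns io_expr if no parent shape infer op
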
3 changes: 2 additions & 1 deletion src/common/snippets/include/snippets/op/subgraph.hpp
@@ -140,7 +140,8 @@ class Subgraph : public ov::op::util::SubGraphOp {
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool;
static auto is_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> bool;
static auto get_last_shape_infer_op(const std::shared_ptr<ov::Node>& op, bool downstream) -> std::shared_ptr<ov::Node>;
static auto get_last_child_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> std::shared_ptr<ov::Node>;
static auto get_last_parent_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> std::shared_ptr<ov::Node>;

void data_flow_transformations(const BlockedShapeVector& blocked_input_shapes = {},
const std::vector<ov::element::Type>& input_precisions = {},
116 changes: 67 additions & 49 deletions src/common/snippets/src/lowered/linear_ir.cpp
@@ -371,7 +371,7 @@ VectorDims LinearIR::get_master_shape() const {
if (!m_config.m_enable_domain_optimization && ov::is_type<snippets::op::Brgemm>(source.get_expr()->get_node())) {
master_shape = utils::get_preordered_vdims(source);
} else {
auto last_shape_infer_expr = LinearIR::get_last_shape_infer_expr(out_exprs[0], false);
auto last_shape_infer_expr = LinearIR::get_last_parent_shape_infer_expr(out_exprs[0]);
master_shape = utils::get_preordered_vdims(last_shape_infer_expr->get_input_port_connector(0)->get_source());
}
} else {
@@ -498,72 +498,90 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
return replace_with_expr(old_exprs, new_expr, insertion_place);
}

std::vector<ExpressionPtr> LinearIR::get_shape_infer_expr_seq(const ExpressionPtr& start_expr, bool downstream) {
std::vector<ExpressionPtr> LinearIR::get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
std::vector<ExpressionPtr> shape_infer_exprs;
auto current_exp = start_expr;
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
shape_infer_exprs.push_back(current_exp);
}
if (downstream) {
if (current_exp->get_output_count() == 0)
return shape_infer_exprs;
auto output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
auto first_child = output_consumers.begin()->get_expr();
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
OPENVINO_ASSERT(output_consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
shape_infer_exprs.push_back(first_child);
current_exp = first_child;
if (current_exp->get_output_count() == 0)
return shape_infer_exprs;
auto consumers = current_exp->get_output_port_connector(0)->get_consumers();
auto first_child = consumers.begin()->get_expr();
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
shape_infer_exprs.push_back(first_child);
current_exp = first_child;
if (current_exp->get_output_count() == 0)
break;
consumers = current_exp->get_output_port_connector(0)->get_consumers();
first_child = consumers.begin()->get_expr();
}
break;
output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
first_child = output_consumers.begin()->get_expr();
}
return shape_infer_exprs;
}

std::vector<ExpressionPtr> LinearIR::get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
std::vector<ExpressionPtr> shape_infer_exprs;
auto current_exp = start_expr;
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
shape_infer_exprs.push_back(current_exp);
}
if (current_exp->get_input_count() == 0)
return shape_infer_exprs;
} else {
// upstream
auto input = current_exp->get_input_port_connector(0);
auto first_parent = input->get_source().get_expr();
while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
shape_infer_exprs.push_back(first_parent);
current_exp = first_parent;
if (current_exp->get_input_count() == 0)
return shape_infer_exprs;
auto first_source = current_exp->get_input_port_connector(0)->get_source().get_expr();
while (op::Subgraph::is_shape_infer_op(first_source->get_node())) {
shape_infer_exprs.push_back(first_source);
current_exp = first_source;
if (current_exp->get_input_count() == 0)
break;
first_source = current_exp->get_input_port_connector(0)->get_source().get_expr();
break;
input = current_exp->get_input_port_connector(0);
first_parent = input->get_source().get_expr();
if (!ov::is_type<snippets::op::Store>(first_parent->get_node())) {
// there may also be some LoopEnd consumers of the Store for loop code generation purposes
OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if they don't consume a Store op.");
}
return shape_infer_exprs;
}
return shape_infer_exprs;
}

ExpressionPtr LinearIR::get_last_shape_infer_expr(const ExpressionPtr& start_expr, bool downstream) {
ExpressionPtr LinearIR::get_last_child_shape_infer_expr(const ExpressionPtr& start_expr) {
auto last_exp = start_expr;
if (downstream) {
if (last_exp->get_output_count() == 0)
return last_exp;
auto consumers = last_exp->get_output_port_connector(0)->get_consumers();
auto first_child = consumers.begin()->get_expr();
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
last_exp = first_child;
if (last_exp->get_output_count() == 0)
return last_exp;
auto consumers = last_exp->get_output_port_connector(0)->get_consumers();
auto first_child = consumers.begin()->get_expr();
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
last_exp = first_child;
if (last_exp->get_output_count() == 0)
break;
consumers = last_exp->get_output_port_connector(0)->get_consumers();
first_child = consumers.begin()->get_expr();
}
break;
consumers = last_exp->get_output_port_connector(0)->get_consumers();
first_child = consumers.begin()->get_expr();
}
return last_exp;
}

ExpressionPtr LinearIR::get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr) {
auto last_exp = start_expr;
if (last_exp->get_input_count() == 0)
return last_exp;
} else {
// upstream
auto input = last_exp->get_input_port_connector(0);
auto first_parent = input->get_source().get_expr();
while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
last_exp = first_parent;
if (last_exp->get_input_count() == 0)
return last_exp;
auto first_source = last_exp->get_input_port_connector(0)->get_source().get_expr();
while (op::Subgraph::is_shape_infer_op(first_source->get_node())) {
last_exp = first_source;
if (last_exp->get_input_count() == 0)
break;
first_source = last_exp->get_input_port_connector(0)->get_source().get_expr();
break;
input = last_exp->get_input_port_connector(0);
first_parent = input->get_source().get_expr();
if (!ov::is_type<snippets::op::Store>(first_parent->get_node())) {
// there may also be some LoopEnd consumers of the Store for loop code generation purposes
OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if they don't consume a Store op.");
}
return last_exp;
}
return last_exp;
}

LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs)
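
For readability, here is a self-contained toy model of the two traversals implemented above. The Expr struct, the helper names, and the Parameter -> RankNormalization -> Reshape -> Add chain are illustrative stand-ins, not the real snippets classes; the real implementation additionally walks port connectors and enforces the single-consumer assertions shown in the diff.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Expr {
    std::string name;
    bool is_shape_infer = false;
    std::shared_ptr<Expr> parent;                 // source of input port 0
    std::vector<std::shared_ptr<Expr>> children;  // consumers of output port 0
};

// Mirrors get_child_shape_infer_expr_seq: collect start_expr if it is a shape infer op,
// then walk downstream while the first consumer is a shape infer op.
std::vector<std::shared_ptr<Expr>> child_shape_infer_seq(std::shared_ptr<Expr> cur) {
    std::vector<std::shared_ptr<Expr>> seq;
    if (cur->is_shape_infer)
        seq.push_back(cur);
    while (!cur->children.empty() && cur->children.front()->is_shape_infer) {
        cur = cur->children.front();
        seq.push_back(cur);
    }
    return seq;
}

// Mirrors get_parent_shape_infer_expr_seq: the same walk, but upstream through the source.
std::vector<std::shared_ptr<Expr>> parent_shape_infer_seq(std::shared_ptr<Expr> cur) {
    std::vector<std::shared_ptr<Expr>> seq;
    if (cur->is_shape_infer)
        seq.push_back(cur);
    while (cur->parent && cur->parent->is_shape_infer) {
        cur = cur->parent;
        seq.push_back(cur);
    }
    return seq;
}

int main() {
    auto make = [](const std::string& name, bool is_shape_infer) {
        auto e = std::make_shared<Expr>();
        e->name = name;
        e->is_shape_infer = is_shape_infer;
        return e;
    };
    // Parameter -> RankNormalization -> Reshape -> Add
    auto param = make("Parameter", false), rank = make("RankNormalization", true),
         reshape = make("Reshape", true), add = make("Add", false);
    param->children = {rank};    rank->parent = param;
    rank->children = {reshape};  reshape->parent = rank;
    reshape->children = {add};   add->parent = reshape;

    for (const auto& e : child_shape_infer_seq(param))  // prints RankNormalization, Reshape
        std::cout << e->name << '\n';
    for (const auto& e : parent_shape_infer_seq(add))   // prints Reshape, RankNormalization
        std::cout << e->name << '\n';
}
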
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/allocate_buffers.cpp
@@ -46,7 +46,7 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
}
}
// Propagate downstream: into Load. A Buffer can have several Loads
auto last_shape_infer = ov::snippets::lowered::LinearIR::get_last_shape_infer_expr(buffer_expr, true);
auto last_shape_infer = ov::snippets::lowered::LinearIR::get_last_child_shape_infer_expr(buffer_expr);
const auto& buffer_out = last_shape_infer->get_output_port_connector(0);
for (const auto& child_expr_input : buffer_out->get_consumers()) {
const auto& child_expr = child_expr_input.get_expr();
6 changes: 3 additions & 3 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
@@ -82,7 +82,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
manually_assigned_gprs[out_connector] = io_expr->get_index();
// TODO [96434]: Support shape infer ops in an arbitrary place in the pipeline, not just after inputs
// shape infer op sequence after the input
auto shape_infer_consumers = LinearIR::get_shape_infer_expr_seq(io_expr, true);
auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(io_expr);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
@@ -91,7 +91,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
// shape infer op sequence before the result
auto shape_infer_sources = LinearIR::get_shape_infer_expr_seq(io_expr, false);
auto shape_infer_sources = LinearIR::get_parent_shape_infer_expr_seq(io_expr);
if (!shape_infer_sources.empty()) {
for (const auto& parent_shape_infer_expr : shape_infer_sources) {
manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
@@ -108,7 +108,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
static_cast<Reg>(num_results + num_parameters + buffer_id);
// shape infer ops in the middle of the subgraph. IntermediateMemoryBuffer is inserted before the reshape, as a new loop should start.
// child shape infer ops share the same memory as the IntermediateMemoryBuffer.
auto shape_infer_consumers = LinearIR::get_shape_infer_expr_seq(expr, true);
auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(expr);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
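
The register bookkeeping above can be pictured with a small stand-alone sketch (names and indices are invented, and the real pass keys manually_assigned_gprs by port connectors rather than strings): a shape infer expression never receives a GPR of its own, it reuses the register of the Parameter, Result, or Buffer it is attached to.

#include <cstddef>
#include <iostream>
#include <map>
#include <string>

int main() {
    using Reg = std::size_t;
    const Reg num_parameters = 1;
    std::map<std::string, Reg> manually_assigned_gprs;

    // Parameter 0 owns gpr 0; its child shape infer op forwards the same data pointer.
    manually_assigned_gprs["param0.out"]   = 0;
    manually_assigned_gprs["reshape0.out"] = manually_assigned_gprs["param0.out"];

    // Result 0 owns gpr num_parameters + 0; its parent shape infer op reuses it.
    manually_assigned_gprs["result0.in"]   = num_parameters + 0;
    manually_assigned_gprs["reshape1.in"]  = manually_assigned_gprs["result0.in"];

    // An IntermediateMemoryBuffer in the middle of the subgraph gets its own gpr;
    // the child shape infer op shares that memory instead of taking a new register.
    manually_assigned_gprs["buffer0.out"]  = 5;
    manually_assigned_gprs["ranknorm0.in"] = manually_assigned_gprs["buffer0.out"];

    for (const auto& kv : manually_assigned_gprs)
        std::cout << kv.first << " -> gpr" << kv.second << '\n';
}
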
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/insert_buffers.cpp
@@ -153,7 +153,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
bool has_shape_infer_parent = false;
auto top_shape_infer_expr = expr;
// the parent before the shape infer ops is used to determine whether a buffer is needed, according to the loopInfo
auto shape_infer_parents = LinearIR::get_shape_infer_expr_seq(first_parent_expr, false);
auto shape_infer_parents = LinearIR::get_parent_shape_infer_expr_seq(first_parent_expr);
if (!shape_infer_parents.empty()) {
parent_expr_output = shape_infer_parents.back()->get_input_port_connector(0)->get_source();
has_shape_infer_parent = true;
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
@@ -36,7 +36,7 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const {

bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
std::shared_ptr<Expression> data_expr = *data_expr_it;
data_expr = LinearIR::get_last_shape_infer_expr(data_expr, true);
data_expr = LinearIR::get_last_child_shape_infer_expr(data_expr);
const auto& data_ngraph_output = data_expr->get_node()->output(0);
bool was_inserted = false;
const auto& data_out = data_expr->get_output_port_connector(0);
@@ -57,7 +57,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr

bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
auto data_expr = *data_expr_it;
data_expr = LinearIR::get_last_shape_infer_expr(data_expr, false);
data_expr = LinearIR::get_last_parent_shape_infer_expr(data_expr);

const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
const auto& parent_expr = parent_output.get_expr();
6 changes: 3 additions & 3 deletions src/common/snippets/src/lowered/pass/validate.cpp
@@ -32,7 +32,7 @@ void validate_ports(const ExpressionPtr& expr) {
void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Parameter>(expr->get_node()),
"Parameter validation expects Parameter op");
auto expr_val = LinearIR::get_last_shape_infer_expr(expr, true);
auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers();
std::set<std::vector<size_t>> layouts;
for (const auto& consumer_input : consumer_inputs) {
@@ -51,7 +51,7 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Result>(expr->get_node()),
"Result validation expects Result op");
auto expr_val = LinearIR::get_last_shape_infer_expr(expr, false);
auto expr_val = LinearIR::get_last_parent_shape_infer_expr(expr);
const auto source = expr_val->get_input_port_connector(0)->get_source();
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()),
@@ -66,7 +66,7 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) {
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()),
"Buffer expects MemoryAccess parent");
auto expr_val = LinearIR::get_last_shape_infer_expr(expr, true);
auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
const auto& out = expr_val->get_output_port_connector(0);
const auto consumers = out->get_consumers();
for (const auto& consumer_input : consumers) {