Skip to content

Commit

Permalink
Apply Ivan review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Mar 26, 2024
1 parent 5cd9250 commit aa221ae
Show file tree
Hide file tree
Showing 14 changed files with 121 additions and 164 deletions.
28 changes: 0 additions & 28 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,34 +223,6 @@ class LinearIR {
*/
exprIt replace_with_expr(const std::vector<ExpressionPtr>& old_exprs, const ExpressionPtr& new_expr);

/**
* @brief Get zero to several consecutive child shape infer exprs(such as reshape, rankNormalization) from start_expr.
* @param start_expr Collect from start_expr.
* @return shape infer expression consumers as a sequence.
*/
static std::vector<ExpressionPtr> get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr);

/**
* @brief Get zero to several consecutive parent shape infer exprs(such as reshape, rankNormalization) from start_expr.
* @param start_expr Collect from start_expr.
* @return shape infer expression sources as a sequence.
*/
static std::vector<ExpressionPtr> get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr);

/**
* @brief Get the last child shape infer op from start_expr in a sequence. If no shape infer op is connected to start_expr, return start_expr.
* @param start_expr Search from start_expr.
* @return last child shape infer expr
*/
static ExpressionPtr get_last_child_shape_infer_expr(const ExpressionPtr& start_expr);

/**
* @brief Get the last parent shape infer op from start_expr in a sequence. If no shape infer op is connected to start_expr, return start_expr.
* @param start_expr Search from start_expr.
* @return last parent shape infer expr
*/
static ExpressionPtr get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr);

private:
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class ReduceShapeInfer : public IShapeInferSnippets {
};

class ReshapeShapeInfer : public IShapeInferSnippets {
ov::PartialShape target_shape;
VectorDims target_shape;
size_t target_shape_volume = 0;
public:
explicit ReshapeShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
Expand Down
19 changes: 18 additions & 1 deletion src/common/snippets/include/snippets/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
* @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order.
*/
VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port);
/* --------------------------- */

/**
* @brief Returns element count of a shape
* @param shape input shape
Expand All @@ -162,7 +164,22 @@ VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_po
inline auto get_shape_size(const VectorDims& shape) -> size_t {
return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}
/* --------------------------- */

/**
* @brief Get the sequence of zero or more consecutive child shape-infer expressions (such as Reshape, RankNormalization) starting from start_expr.
* Since a node may have multiple outputs, this function returns the first (leftmost) valid sequence.
* @param start_expr Expression from which the collection starts.
* @return Shape infer expression consumers as a sequence.
*/
std::vector<lowered::ExpressionPtr> get_first_child_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr);

/**
* @brief Get the sequence of zero or more consecutive parent shape-infer expressions (such as Reshape, RankNormalization) starting from start_expr.
* Since a node may have multiple inputs, this function returns the first (leftmost) valid sequence.
* @param start_expr Expression from which the collection starts.
* @return Shape infer expression sources as a sequence.
*/
std::vector<lowered::ExpressionPtr> get_first_parent_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr);

} // namespace utils
} // namespace snippets
Expand Down
91 changes: 3 additions & 88 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,9 @@ VectorDims LinearIR::get_master_shape() const {
if (!m_config.m_enable_domain_optimization && ov::is_type<snippets::op::Brgemm>(source.get_expr()->get_node())) {
master_shape = utils::get_preordered_vdims(source);
} else {
auto last_shape_infer_expr = LinearIR::get_last_parent_shape_infer_expr(out_exprs[0]);
master_shape = utils::get_preordered_vdims(last_shape_infer_expr->get_input_port_connector(0)->get_source());
const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(out_exprs[0]);
const auto& expr = shape_infer_seq.empty() ? out_exprs[0] : shape_infer_seq.back();
master_shape = utils::get_preordered_vdims(expr->get_input_port_connector(0)->get_source());
}
} else {
for (const auto& oe : out_exprs) {
Expand Down Expand Up @@ -498,92 +499,6 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
return replace_with_expr(old_exprs, new_expr, insertion_place);
}

// Collects the chain of consecutive shape-infer expressions (such as reshape, rankNormalization)
// reachable downward from start_expr. Only output port 0 is followed at each step, so a single
// linear chain is returned; the result is empty when start_expr has no shape-infer children.
std::vector<ExpressionPtr> LinearIR::get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
std::vector<ExpressionPtr> shape_infer_exprs;
auto current_exp = start_expr;
// If start_expr is itself a shape-infer op, it opens the sequence. It must then be the
// sole consumer of its input connector, otherwise the chain would be ambiguous.
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
shape_infer_exprs.push_back(current_exp);
}
// Nothing to walk when the expression produces no outputs.
if (current_exp->get_output_count() == 0)
return shape_infer_exprs;
auto output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
auto first_child = output_consumers.begin()->get_expr();
// Descend while the first consumer of output 0 keeps being a shape-infer op.
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
// A shape-infer child must be the only consumer of its parent's output.
OPENVINO_ASSERT(output_consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
shape_infer_exprs.push_back(first_child);
current_exp = first_child;
// Stop at an expression with no outputs (end of the graph).
if (current_exp->get_output_count() == 0)
break;
output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
first_child = output_consumers.begin()->get_expr();
}
return shape_infer_exprs;
}

// Collects the chain of consecutive shape-infer expressions (such as reshape, rankNormalization)
// reachable upward from start_expr. Only input port 0 is followed at each step, so a single
// linear chain is returned; the result is empty when start_expr has no shape-infer parents.
std::vector<ExpressionPtr> LinearIR::get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
std::vector<ExpressionPtr> shape_infer_exprs;
auto current_exp = start_expr;
// If start_expr is itself a shape-infer op, it opens the sequence. NOTE(review): unlike the
// loop below, this initial check has no Store exception — confirm that is intentional.
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
shape_infer_exprs.push_back(current_exp);
}
// Nothing to walk when the expression consumes no inputs.
if (current_exp->get_input_count() == 0)
return shape_infer_exprs;
auto input = current_exp->get_input_port_connector(0);
auto first_parent = input->get_source().get_expr();
// Ascend while the source of input 0 keeps being a shape-infer op.
while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
shape_infer_exprs.push_back(first_parent);
current_exp = first_parent;
// Stop at an expression with no inputs (start of the graph).
if (current_exp->get_input_count() == 0)
break;
input = current_exp->get_input_port_connector(0);
first_parent = input->get_source().get_expr();
if (!ov::is_type<snippets::op::Store>(first_parent->get_node())) {
// there are maybe some loopEnd consumers of store as well for loop code gen purpose,
// so the single-consumer requirement is only enforced for non-Store parents
OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops.");
}
}
return shape_infer_exprs;
}

// Walks down from start_expr through the chain of shape infer ops attached to the first
// output at each step and returns the deepest one. If start_expr has no shape infer child
// (or no outputs at all), start_expr itself is returned.
ExpressionPtr LinearIR::get_last_child_shape_infer_expr(const ExpressionPtr& start_expr) {
    auto result = start_expr;
    while (result->get_output_count() != 0) {
        const auto& consumers = result->get_output_port_connector(0)->get_consumers();
        const auto& candidate = consumers.begin()->get_expr();
        if (!op::Subgraph::is_shape_infer_op(candidate->get_node()))
            break;
        // A shape infer child must be the only consumer of its parent's output.
        OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
        result = candidate;
    }
    return result;
}

// Walks up from start_expr through the chain of shape infer ops feeding the first input
// at each step and returns the deepest one. If start_expr has no shape infer parent
// (or no inputs at all), start_expr itself is returned.
ExpressionPtr LinearIR::get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr) {
    auto result = start_expr;
    for (;;) {
        if (result->get_input_count() == 0)
            return result;
        const auto& in_connector = result->get_input_port_connector(0);
        const auto& parent = in_connector->get_source().get_expr();
        // The consumer check applies only once a shape infer op has been reached
        // (i.e. not on the very first hop from start_expr).
        if (result != start_expr && !ov::is_type<snippets::op::Store>(parent->get_node())) {
            // there are maybe some loopEnd consumers of store as well for loop code gen purpose
            OPENVINO_ASSERT(in_connector->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops.");
        }
        if (!op::Subgraph::is_shape_infer_op(parent->get_node()))
            return result;
        result = parent;
    }
}

LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs)
: ShapeInferSnippetsNode(),
m_exprs{std::make_shared<container>(body_exprs)} {
Expand Down
6 changes: 4 additions & 2 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "snippets/lowered/pass/normalize_buffer_ids.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils.hpp"

namespace ov {
namespace snippets {
Expand Down Expand Up @@ -46,8 +47,9 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
}
}
// Propagate to down: in Load. Buffer can have several Load
auto last_shape_infer = ov::snippets::lowered::LinearIR::get_last_child_shape_infer_expr(buffer_expr);
const auto& buffer_out = last_shape_infer->get_output_port_connector(0);
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(buffer_expr);
const auto& target_expr = shape_infer_seq.empty() ? buffer_expr : shape_infer_seq.back();
const auto& buffer_out = target_expr->get_output_port_connector(0);
for (const auto& child_expr_input : buffer_out->get_consumers()) {
const auto& child_expr = child_expr_input.get_expr();
const auto port = child_expr_input.get_index();
Expand Down
30 changes: 12 additions & 18 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include "snippets/lowered/pass/assign_registers.hpp"

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils.hpp"

// This header is needed to avoid MSVC warning "C2039: 'inserter': is not a member of 'std'"
#include <iterator>
Expand Down Expand Up @@ -82,20 +82,16 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
manually_assigned_gprs[out_connector] = io_expr->get_index();
// TODO [96434]: Support shape infer ops in arbitrary place in pipeline, not just after inputs
// shape infer ops sequence after input
auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(io_expr);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
}
const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(io_expr);
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
}
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
// shape infer ops sequence before result
auto shape_infer_sources = LinearIR::get_parent_shape_infer_expr_seq(io_expr);
if (!shape_infer_sources.empty()) {
for (const auto& parent_shape_infer_expr : shape_infer_sources) {
manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
}
const auto& shape_infer_sources = utils::get_first_parent_shape_infer_expr_seq(io_expr);
for (const auto& parent_shape_infer_expr : shape_infer_sources) {
manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
}
} else {
OPENVINO_THROW("Unsupported io_type detected");
Expand All @@ -108,13 +104,11 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
static_cast<Reg>(num_results + num_parameters + buffer_id);
// shape infer ops in the middle of subgraph. IntermediateMemoryBuffer is inserted before reshape as new loop should start.
// child shape info ops share the same memory as IntermediateMemoryBuffer.
auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(expr);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
}
const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(expr);
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
}
}
manually_assigned_gprs[expr->get_output_port_connector(0)] =
Expand Down
8 changes: 3 additions & 5 deletions src/common/snippets/src/lowered/pass/insert_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,17 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
const auto port_idx = entry_port->get_index();
const auto node = expr->get_node();
auto parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();

const auto& first_parent_expr = parent_expr_output.get_expr();
auto parent_expr = parent_expr_output.get_expr();
bool has_shape_infer_parent = false;
auto top_shape_infer_expr = expr;
// parent before shape infer ops is used to determine if buffer needed according loopInfo
auto shape_infer_parents = LinearIR::get_parent_shape_infer_expr_seq(first_parent_expr);
const auto& shape_infer_parents = utils::get_first_parent_shape_infer_expr_seq(parent_expr);
if (!shape_infer_parents.empty()) {
parent_expr_output = shape_infer_parents.back()->get_input_port_connector(0)->get_source();
has_shape_infer_parent = true;
top_shape_infer_expr = shape_infer_parents.back();
parent_expr = parent_expr_output.get_expr();
}

const auto& parent_expr = parent_expr_output.get_expr();
const auto& parent_port = parent_expr_output.get_index();
const auto& parent = parent_expr->get_node();
if (ov::is_type<op::Buffer>(parent) ||
Expand Down
8 changes: 4 additions & 4 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const {
}

bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
std::shared_ptr<Expression> data_expr = *data_expr_it;
data_expr = LinearIR::get_last_child_shape_infer_expr(data_expr);
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(*data_expr_it);
const std::shared_ptr<Expression>& data_expr = shape_infer_seq.empty() ? *data_expr_it : shape_infer_seq.back();
const auto& data_ngraph_output = data_expr->get_node()->output(0);
bool was_inserted = false;
const auto& data_out = data_expr->get_output_port_connector(0);
Expand All @@ -56,8 +56,8 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
}

bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
auto data_expr = *data_expr_it;
data_expr = LinearIR::get_last_parent_shape_infer_expr(data_expr);
const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(*data_expr_it);
const auto& data_expr = shape_infer_seq.empty() ? *data_expr_it : shape_infer_seq.back();

const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
const auto& parent_expr = parent_output.get_expr();
Expand Down
9 changes: 6 additions & 3 deletions src/common/snippets/src/lowered/pass/validate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ void validate_ports(const ExpressionPtr& expr) {
void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Parameter>(expr->get_node()),
"Parameter validation expects Parameter op");
auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(expr);
const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers();
std::set<std::vector<size_t>> layouts;
for (const auto& consumer_input : consumer_inputs) {
Expand All @@ -51,7 +52,8 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Result>(expr->get_node()),
"Result validation expects Result op");
auto expr_val = LinearIR::get_last_parent_shape_infer_expr(expr);
const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(expr);
const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
const auto source = expr_val->get_input_port_connector(0)->get_source();
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()),
Expand All @@ -66,7 +68,8 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) {
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()),
"Buffer expects MemoryAccess parent");
auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(expr);
const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
const auto& out = expr_val->get_output_port_connector(0);
const auto consumers = out->get_consumers();
for (const auto& consumer_input : consumers) {
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ auto Subgraph::get_last_child_shape_infer_op(const std::shared_ptr<ov::Node>& op
return last_op;
auto consumers = last_op->get_output_target_inputs(0);
auto first_child = consumers.begin()->get_node()->shared_from_this();
while (op::Subgraph::is_shape_infer_op(first_child)) {
while (is_shape_infer_op(first_child)) {
OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
last_op = first_child;
if (last_op->get_output_size() == 0)
Expand All @@ -109,7 +109,7 @@ auto Subgraph::get_last_parent_shape_infer_op(const std::shared_ptr<ov::Node>& o
if (last_op->get_input_size() == 0)
return last_op;
auto first_parent = last_op->get_input_node_shared_ptr(0);
while (op::Subgraph::is_shape_infer_op(first_parent)) {
while (is_shape_infer_op(first_parent)) {
last_op = first_parent;
if (last_op->get_input_size() == 0)
break;
Expand Down
Loading

0 comments on commit aa221ae

Please sign in to comment.