apply Alexandra comments conti
chenhu-wang committed Mar 18, 2024
1 parent 2e0c476 commit c810ef1
Showing 20 changed files with 183 additions and 155 deletions.
8 changes: 8 additions & 0 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
@@ -223,6 +223,14 @@ class LinearIR {
*/
exprIt replace_with_expr(const std::vector<ExpressionPtr>& old_exprs, const ExpressionPtr& new_expr);

/**
* @brief Propagate start_expr through zero or more consecutive shape infer exprs (such as Reshape or RankNormalization).
* @param start_expr The expression to start the propagation from.
* @param downstream If true, propagate downstream; otherwise propagate upstream.
* @return The sequence of shape infer consumers if downstream, or the sequence of shape infer sources if upstream.
*/
static std::vector<ExpressionPtr> propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream);

private:
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;

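For orientation, a minimal usage sketch of the new helper, mirroring how the passes below consume it (io_expr is a hypothetical ExpressionPtr for an input-side expression; only the helper's signature comes from the declaration above):

// Collect the chain of shape infer ops hanging off an input expression.
auto consumers = LinearIR::propagate_expr_through_shape_infer_ops(io_expr, /*downstream=*/true);
// If the chain is non-empty, its last element is the deepest shape infer op,
// i.e. the expression whose output port the caller should actually work with.
auto anchor = consumers.empty() ? io_expr : consumers.back();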
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/op/subgraph.hpp
@@ -139,6 +139,7 @@ class Subgraph : public ov::op::util::SubGraphOp {
// Return estimated unique buffer count (upper bound). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool;
static auto is_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> bool;

void data_flow_transformations(const BlockedShapeVector& blocked_input_shapes = {},
const std::vector<ov::element::Type>& input_precisions = {},
3 changes: 3 additions & 0 deletions src/common/snippets/include/snippets/utils.hpp
@@ -154,6 +154,9 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
* @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order.
*/
VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port);
inline auto get_shape_size(const VectorDims& shape) -> size_t {
return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}
/* --------------------------- */

} // namespace utils
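A standalone sanity check of get_shape_size, a sketch assuming VectorDims aliases std::vector<size_t> as in the snippets code base:

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

using VectorDims = std::vector<size_t>;

int main() {
    const VectorDims shape{2, 3, 4};
    // The same product get_shape_size computes: 2 * 3 * 4 = 24 elements.
    const auto size = std::accumulate(shape.begin(), shape.end(),
                                      static_cast<size_t>(1), std::multiplies<size_t>());
    assert(size == 24);
    return 0;
}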
38 changes: 38 additions & 0 deletions src/common/snippets/src/lowered/linear_ir.cpp
@@ -12,6 +12,7 @@
#include "openvino/core/graph_util.hpp"
#include "openvino/core/type.hpp"
#include "snippets/utils.hpp"
#include "snippets/op/subgraph.hpp"

namespace ov {
namespace snippets {
@@ -496,6 +497,43 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
return replace_with_expr(old_exprs, new_expr, insertion_place);
}

std::vector<ExpressionPtr> LinearIR::propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream) {
std::vector<ExpressionPtr> shape_infer_exprs;
auto current_exp = start_expr;
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
shape_infer_exprs.push_back(current_exp);
}
if (downstream) {
if (current_exp->get_output_count() == 0)
return shape_infer_exprs;
auto consumers = current_exp->get_output_port_connector(0)->get_consumers();
auto first_child = consumers.begin()->get_expr();
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
OPENVINO_ASSERT(consumers.size() == 1, "The shape infer op is supposed to be the only consumer.");
shape_infer_exprs.push_back(first_child);
current_exp = first_child;
if (current_exp->get_output_count() == 0)
break;
consumers = current_exp->get_output_port_connector(0)->get_consumers();
first_child = consumers.begin()->get_expr();
}
return shape_infer_exprs;
} else {
// upstream
if (current_exp->get_input_count() == 0)
return shape_infer_exprs;
auto first_source = current_exp->get_input_port_connector(0)->get_source().get_expr();
while (op::Subgraph::is_shape_infer_op(first_source->get_node())) {
shape_infer_exprs.push_back(first_source);
current_exp = first_source;
if (current_exp->get_input_count() == 0)
break;
first_source = current_exp->get_input_port_connector(0)->get_source().get_expr();
}
return shape_infer_exprs;
}
}

LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs)
: ShapeInferSnippetsNode(),
m_exprs{std::make_shared<container>(body_exprs)} {
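The traversal is easiest to follow on a toy chain. This sketch re-implements only the downstream branch over a simplified expression type; Expr, is_shape_infer, and the chain are illustrative stand-ins, not the real snippets classes (the real helper also records start_expr itself when it is a shape infer op):

#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for an expression with a single output connector.
struct Expr {
    std::string op;
    std::vector<std::shared_ptr<Expr>> consumers;  // consumers of output 0
};
using ExprPtr = std::shared_ptr<Expr>;

static bool is_shape_infer(const ExprPtr& e) {
    return e->op == "Reshape" || e->op == "RankNormalization";
}

// Mirrors the downstream branch: collect consecutive shape infer consumers,
// stopping at the first consumer that is not a shape infer op.
static std::vector<ExprPtr> propagate_downstream(const ExprPtr& start) {
    std::vector<ExprPtr> result;
    auto current = start;
    while (!current->consumers.empty()) {
        auto first_child = current->consumers.front();
        if (!is_shape_infer(first_child))
            break;
        assert(current->consumers.size() == 1);  // must be the only consumer
        result.push_back(first_child);
        current = first_child;
    }
    return result;
}

int main() {
    // Parameter -> RankNormalization -> Reshape -> Load
    auto load = std::make_shared<Expr>(Expr{"Load", {}});
    auto reshape = std::make_shared<Expr>(Expr{"Reshape", {load}});
    auto rank_norm = std::make_shared<Expr>(Expr{"RankNormalization", {reshape}});
    auto param = std::make_shared<Expr>(Expr{"Parameter", {rank_norm}});

    const auto seq = propagate_downstream(param);
    assert(seq.size() == 2);              // RankNormalization, Reshape
    assert(seq.back()->op == "Reshape");  // the deepest shape infer op
    return 0;
}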
7 changes: 4 additions & 3 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
@@ -54,12 +54,13 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
auto memory_access = ov::as_type_ptr<ov::snippets::op::MemoryAccess>(child_node);
if (memory_access && memory_access->is_memory_access_input_port(port)) {
memory_access->set_input_offset(offset, port);
} else if (ov::is_type<op::LoopEnd>(child_node)) {
} else if (ov::is_type<op::LoopEnd>(child_node) || op::Subgraph::is_shape_infer_op(child_node)) {
// After Loop initialization, Buffer can be connected to LoopEnd - it's ok
// There can also be Buffers before shape-changing ops
continue;
} else {
// OPENVINO_THROW(
// "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
OPENVINO_THROW(
"Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
}
}
}
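The dispatch set_buffer_offset now performs for each Buffer consumer, mirrored in isolation; op kinds are plain strings here rather than the real node types:

#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// MemoryAccess consumers receive the offset; LoopEnd and shape infer ops are
// legal consumers with nothing to set; anything else is an error again now
// that the throw is re-enabled.
void propagate_offset(const std::vector<std::string>& children) {
    for (const auto& child : children) {
        if (child == "MemoryAccess") {
            // set_input_offset(offset, port) would run here
        } else if (child == "LoopEnd" || child == "Reshape" ||
                   child == "RankNormalization") {
            continue;
        } else {
            throw std::runtime_error("Buffer child cannot accept an offset");
        }
    }
}

int main() {
    propagate_offset({"MemoryAccess", "LoopEnd", "Reshape"});  // accepted
    try {
        propagate_offset({"Add"});
        assert(false);  // unreachable: "Add" must throw
    } catch (const std::runtime_error&) {
    }
    return 0;
}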
44 changes: 23 additions & 21 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
@@ -5,6 +5,7 @@
#include "snippets/lowered/pass/assign_registers.hpp"

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/itt.hpp"

@@ -79,23 +80,22 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
if (io_expr->get_type() == IOExpression::io_type::INPUT) {
const auto& out_connector = expr->get_output_port_connector(0);
manually_assigned_gprs[out_connector] = io_expr->get_index();
// TODO [96434]: Support RankNormalization/Reshape in arbitrary place in pipeline, not just after inputs
// reshape rankNormalization sequence
auto consumer_inputs = out_connector->get_consumers();
auto child_exp = consumer_inputs.begin()->get_expr();
while (ov::is_type<op::RankNormalization>(child_exp->get_node()) ||
ov::is_type<op::Reshape>(child_exp->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization or Reshape is supposed to be the only consumer");
manually_assigned_gprs[child_exp->get_output_port_connector(0)] = io_expr->get_index();
consumer_inputs = child_exp->get_output_port_connector(0)->get_consumers();
child_exp = consumer_inputs.begin()->get_expr();
// TODO [96434]: Support shape infer ops in an arbitrary place in the pipeline, not just after inputs
// Shape infer ops sequence after an input
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(io_expr, true);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
}
}
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
// reshape before result
const auto &parent = expr->get_input_port_connector(0)->get_source().get_expr();
if (ov::is_type<op::Reshape>(parent->get_node())) {
manually_assigned_gprs[parent->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
// shape infer ops sequence before result
auto shape_infer_sources = LinearIR::propagate_expr_through_shape_infer_ops(io_expr, false);
if (!shape_infer_sources.empty()) {
for (const auto& parent_shape_infer_expr : shape_infer_sources) {
manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
}
}
} else {
OPENVINO_THROW("Unsupported io_type detected");
@@ -106,13 +106,15 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
if (ov::is_type<op::IntermediateMemoryBuffer>(buffer)) {
manually_assigned_gprs[expr->get_input_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
// reshape in the middle of subgraph. IntermediateMemoryBuffer is inserted before reshape as new loop should start.
const auto& first_consumer = expr->get_output_port_connector(0)->get_consumers().begin()->get_expr();
if (ov::is_type<op::Reshape>(first_consumer->get_node())) {
manually_assigned_gprs[first_consumer->get_input_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
manually_assigned_gprs[first_consumer->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
// Shape infer ops in the middle of the subgraph: an IntermediateMemoryBuffer is inserted before them, as a new loop should start.
// Child shape infer ops share the same memory as the IntermediateMemoryBuffer.
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
}
}
}
manually_assigned_gprs[expr->get_output_port_connector(0)] =
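The register-sharing rule the pass now applies, in isolation: shape infer ops move no data, so every connector along such a chain aliases the register of the underlying Parameter, Result, or Buffer. A simplified sketch with string keys standing in for port connectors:

#include <cassert>
#include <map>
#include <string>
#include <vector>

using Reg = size_t;
using Connector = std::string;  // stand-in for a port-connector key

int main() {
    std::map<Connector, Reg> manually_assigned_gprs;
    const Reg param_reg = 0;  // the register assigned to the Parameter itself

    // The Parameter's output and every connector along its shape infer chain
    // get one and the same register.
    manually_assigned_gprs["param.out"] = param_reg;
    const std::vector<Connector> chain{"rank_norm.out", "reshape.out"};
    for (const auto& c : chain)
        manually_assigned_gprs[c] = param_reg;

    assert(manually_assigned_gprs.at("reshape.out") == param_reg);
    return 0;
}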
33 changes: 13 additions & 20 deletions src/common/snippets/src/lowered/pass/insert_buffers.cpp
@@ -149,15 +149,17 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
const auto node = expr->get_node();
auto parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();

auto first_not_reshape_parent_output = [&]() {
auto parent_expr = parent_expr_output.get_expr();
while (is_type<op::Reshape>(parent_expr->get_node())) {
parent_expr_output = parent_expr->get_input_port_connector(0)->get_source();
parent_expr = parent_expr_output.get_expr();
}
};
// this parent(before reshape) is used to determine if buffer needed according loopInfo
first_not_reshape_parent_output();
const auto& first_parent_expr = parent_expr_output.get_expr();
bool has_shape_infer_parent = false;
auto top_shape_infer_expr = expr;
// The parent before the shape infer ops is used to determine whether a buffer is needed according to the loopInfo
auto shape_infer_parents = LinearIR::propagate_expr_through_shape_infer_ops(first_parent_expr, false);
if (!shape_infer_parents.empty()) {
parent_expr_output = shape_infer_parents.back()->get_input_port_connector(0)->get_source();
has_shape_infer_parent = true;
top_shape_infer_expr = shape_infer_parents.back();
}

const auto& parent_expr = parent_expr_output.get_expr();
const auto& parent_port = parent_expr_output.get_index();
const auto& parent = parent_expr->get_node();
@@ -167,15 +169,6 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
ov::is_type<ov::op::v0::Constant>(parent))
continue;

// insert buffer before reshape
auto buffer_child = expr;
bool parent_is_reshape = false;
auto p_exp = expr->get_input_port_connector(port_idx)->get_source().get_expr();
if (is_type<op::Reshape>(p_exp->get_node())) {
buffer_child = p_exp;
parent_is_reshape = true;
}

// Each MemoryAccess op needs Buffer
const auto parent_ma = ov::as_type_ptr<op::MemoryAccess>(parent);
const auto node_ma = ov::as_type_ptr<op::MemoryAccess>(node);
@@ -197,9 +190,9 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
parent_expr_output,
m_buffer_allocation_rank);
const auto buffer = std::make_shared<op::IntermediateMemoryBuffer>(parent->output(parent_port), allocation_shape);
if (parent_is_reshape) {
if (has_shape_infer_parent) {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos,
{ buffer_child->get_input_port(0) });
{ top_shape_infer_expr->get_input_port(0) });
} else {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port });
}
39 changes: 12 additions & 27 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
@@ -36,23 +36,9 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const {

bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
std::shared_ptr<Expression> data_expr = *data_expr_it;
const auto& consumer_inputs = data_expr->get_output_port_connector(0)->get_consumers();
auto first_reshape_consumer = [&]() {
auto current_exp = data_expr;
auto first_consumer = consumer_inputs.begin()->get_expr();
while (1) {
if (is_type<op::RankNormalization>(first_consumer->get_node()) ||
is_type<op::Reshape>(first_consumer->get_node())) {
current_exp = first_consumer;
first_consumer = first_consumer->get_output_port_connector(0)->get_consumers().begin()->get_expr();
// OPENVINO_ASSERT(current_exp->get_output_port_connector(0)->get_consumers().size() == 1,
// "RankNormalization or Reshape is supposed to be the only consumer");
} else {
return current_exp;
}
}
};
data_expr = first_reshape_consumer();
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, true);
if (!shape_infer_consumers.empty())
data_expr = shape_infer_consumers.back();

const auto& data_ngraph_output = data_expr->get_node()->output(0);
bool was_inserted = false;
@@ -74,16 +60,15 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr

bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
auto data_expr = *data_expr_it;
auto parent_output = data_expr->get_input_port_connector(0)->get_source();
auto parent_expr = parent_output.get_expr();
if (is_type<op::Reshape>(parent_expr->get_node())) {
data_expr = parent_expr;
parent_output = data_expr->get_input_port_connector(0)->get_source();
parent_expr = parent_output.get_expr();
}
auto port = parent_output.get_index();
auto parent = parent_expr->get_node();
auto ma = ov::as_type_ptr<op::MemoryAccess>(parent);
auto shape_infer_sources = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, false);
if (!shape_infer_sources.empty())
data_expr = shape_infer_sources.back();

const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
const auto& parent_expr = parent_output.get_expr();
const auto port = parent_output.get_index();
const auto& parent = parent_expr->get_node();
const auto ma = ov::as_type_ptr<op::MemoryAccess>(parent);
if (ma && ma->is_memory_access_output_port(port))
return false;

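The net effect of the two rewrites above on op order, shown on a toy pipeline; a sketch where ops are plain strings and the shape infer chains are assumed, not taken from a real graph:

#include <cassert>
#include <iterator>
#include <list>
#include <string>

int main() {
    std::list<std::string> ops{"Parameter", "RankNormalization", "Reshape",
                               "Add", "Reshape", "Result"};

    // insert_load: skip the shape infer chain that follows the Parameter,
    // then attach the Load to the deepest shape infer op's output.
    auto load_pos = std::next(ops.begin());
    while (*load_pos == "RankNormalization" || *load_pos == "Reshape")
        ++load_pos;
    ops.insert(load_pos, "Load");

    // insert_store: skip the shape infer chain that precedes the Result,
    // then attach the Store above the topmost shape infer op.
    auto store_pos = std::prev(ops.end());
    while (*std::prev(store_pos) == "Reshape")
        --store_pos;
    ops.insert(store_pos, "Store");

    assert((ops == std::list<std::string>{"Parameter", "RankNormalization",
            "Reshape", "Load", "Add", "Store", "Reshape", "Result"}));
    return 0;
}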
19 changes: 10 additions & 9 deletions src/common/snippets/src/lowered/pass/validate.cpp
@@ -32,13 +32,9 @@ void validate_ports(const ExpressionPtr& expr) {
void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Parameter>(expr->get_node()),
"Parameter validation expects Parameter op");
auto consumer_inputs = expr->get_output_port_connector(0)->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
if (is_type<snippets::op::RankNormalization>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1,
"If there is RankNormalization after Parameter, it should be single consumer of the Parameter");
consumer_inputs = first_consumer->get_output_port_connector(0)->get_consumers();
}
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true);
auto expr_val = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back();
auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers();
std::set<std::vector<size_t>> layouts;
for (const auto& consumer_input : consumer_inputs) {
const auto& node = consumer_input.get_expr()->get_node();
@@ -56,7 +52,9 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Result>(expr->get_node()),
"Result validation expects Result op");
const auto source = expr->get_input_port_connector(0)->get_source();
auto shape_infer_parents = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, false);
auto expr_val = shape_infer_parents.empty() ? expr : shape_infer_parents.back();
const auto source = expr_val->get_input_port_connector(0)->get_source();
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()),
"Result expects MemoryAccess parent");
@@ -71,7 +69,10 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()),
"Buffer expects MemoryAccess parent");

const auto& out = expr->get_output_port_connector(0);
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true);
auto expr_val = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back();

const auto& out = expr_val->get_output_port_connector(0);
const auto consumers = out->get_consumers();
for (const auto& consumer_input : consumers) {
const auto& node = consumer_input.get_expr()->get_node();
Expand Down