[Snippets] GN pattern via snippets #2

Closed · wants to merge 2 commits
38 changes: 38 additions & 0 deletions src/common/snippets/include/snippets/op/reshape.hpp
@@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_inference.hpp"

namespace ov {
namespace snippets {
namespace op {

/**
* @interface Reshape
* @brief Reshape input tensor to required target shape
* @ingroup snippets
*/
class Reshape : public ov::op::Op {
public:
OPENVINO_OP("Reshape", "SnippetsOpset");
Reshape(const Output<Node>& x, ov::PartialShape target_shape);
Reshape() = default;

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;

const ov::PartialShape& get_target_shape() const;
void set_target_shape(ov::PartialShape shape);

private:
ov::PartialShape m_target_shape = {};
};

} // namespace op
} // namespace snippets
} // namespace ov
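
For reviewers, a minimal sketch of what the corresponding reshape.cpp could look like; the .cpp is not shown in this diff, so the method bodies below are assumptions inferred from the header:

// Sketch only: inferred from reshape.hpp above, not part of this PR's diff
#include "snippets/op/reshape.hpp"

namespace ov {
namespace snippets {
namespace op {

Reshape::Reshape(const Output<Node>& x, ov::PartialShape target_shape)
    : Op({x}), m_target_shape(std::move(target_shape)) {
    constructor_validate_and_infer_types();
}

void Reshape::validate_and_infer_types() {
    // The op only reinterprets the memory layout: element type is preserved,
    // the output shape becomes the target shape
    set_output_type(0, get_input_element_type(0), m_target_shape);
}

std::shared_ptr<Node> Reshape::clone_with_new_inputs(const OutputVector& new_args) const {
    check_new_args_count(this, new_args);
    return std::make_shared<Reshape>(new_args.at(0), m_target_shape);
}

bool Reshape::visit_attributes(AttributeVisitor& visitor) {
    visitor.on_attribute("target_shape", m_target_shape);
    return true;
}

const ov::PartialShape& Reshape::get_target_shape() const {
    return m_target_shape;
}

void Reshape::set_target_shape(ov::PartialShape shape) {
    m_target_shape = std::move(shape);
}

}  // namespace op
}  // namespace snippets
}  // namespace ov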
@@ -0,0 +1,28 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface TokenizeGroupNormSnippets
* @brief Tokenizes GroupNormalization into a snippets Subgraph
* @ingroup snippets
*/
class TokenizeGroupNormSnippets: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TokenizeGroupNormSnippets", "0");
TokenizeGroupNormSnippets();
};

} // namespace pass
} // namespace snippets
} // namespace ov
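
A sketch of how the matching .cpp constructor could be wired up. It is not part of this diff; the pattern on ov::op::v12::GroupNormalization, the callback body, and the use of op::Subgraph::wrap_node_as_subgraph follow the convention of other snippets tokenization passes and are assumptions here:

// Sketch only: the actual implementation is not shown in this diff.
// The pass's own header include is omitted because its path does not appear above.
#include "openvino/core/rt_info.hpp"
#include "openvino/op/group_normalization.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "snippets/op/subgraph.hpp"

namespace ov {
namespace snippets {
namespace pass {

TokenizeGroupNormSnippets::TokenizeGroupNormSnippets() {
    auto group_norm_pattern = ov::pass::pattern::wrap_type<ov::op::v12::GroupNormalization>();

    matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
        const auto group_norm = m.get_match_root();
        if (group_norm->is_dynamic())
            return false;
        // Wrap the single GroupNormalization node into a snippets Subgraph so that
        // GroupNormalizationDecomposition can later expand it into low-level snippets ops
        const auto subgraph = op::Subgraph::wrap_node_as_subgraph(group_norm);
        subgraph->set_friendly_name(group_norm->get_friendly_name());
        ov::copy_runtime_info(group_norm, subgraph);
        ov::replace_node(group_norm, subgraph);
        return true;
    };

    auto matcher = std::make_shared<ov::pass::pattern::Matcher>(group_norm_pattern, "TokenizeGroupNormSnippets");
    register_matcher(matcher, callback);
}

}  // namespace pass
}  // namespace snippets
}  // namespace ov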
@@ -0,0 +1,27 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface GroupNormalizationDecomposition
* @brief Decomposes GroupNormalization to a range of low-level operations
* @ingroup snippets
*/
class GroupNormalizationDecomposition: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GroupNormalizationDecomposition", "0");
GroupNormalizationDecomposition();
};

} // namespace pass
} // namespace snippets
} // namespace ov
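
The decomposition itself is not in this diff. As a reference for the arithmetic the pass has to produce, here is a sketch of GroupNormalization's math built with standard opset ops; the real pass would emit snippets ops, including the new snippets::op::Reshape, instead, and the helper name below is hypothetical:

// Reference-only sketch of the GroupNormalization arithmetic; not the pass implementation
#include "openvino/opsets/opset10.hpp"

std::shared_ptr<ov::Node> group_norm_reference(const ov::Output<ov::Node>& data,   // [N, C, H, W], static shape
                                               const ov::Output<ov::Node>& scale,  // [C]
                                               const ov::Output<ov::Node>& bias,   // [C]
                                               int64_t num_groups,
                                               float eps) {
    using namespace ov::opset10;
    const auto& shape = data.get_shape();
    const auto n = static_cast<int64_t>(shape[0]);
    const auto c = static_cast<int64_t>(shape[1]);
    const auto inner = static_cast<int64_t>(ov::shape_size(shape)) / (n * num_groups);

    // 1. Collapse each group: [N, C, H, W] -> [N, G, C/G * H * W]
    const auto grouped_shape = Constant::create(ov::element::i64, {3}, {n, num_groups, inner});
    const auto grouped = std::make_shared<Reshape>(data, grouped_shape, false);

    // 2. Per-group mean and variance over the collapsed axis
    const auto axes = Constant::create(ov::element::i64, {1}, {2});
    const auto mean = std::make_shared<ReduceMean>(grouped, axes, true);
    const auto centered = std::make_shared<Subtract>(grouped, mean);
    const auto variance = std::make_shared<ReduceMean>(std::make_shared<Multiply>(centered, centered), axes, true);

    // 3. Normalize: (x - mean) / sqrt(var + eps)
    const auto eps_const = Constant::create(data.get_element_type(), {}, {eps});
    const auto normalized = std::make_shared<Divide>(
        centered, std::make_shared<Sqrt>(std::make_shared<Add>(variance, eps_const)));

    // 4. Restore the original layout and apply per-channel scale and bias
    const auto orig_shape = Constant::create(ov::element::i64, {shape.size()},
                                             std::vector<int64_t>(shape.begin(), shape.end()));
    const auto restored = std::make_shared<Reshape>(normalized, orig_shape, false);
    std::vector<int64_t> channel_shape(shape.size(), 1);
    channel_shape[1] = c;
    const auto channel_dims = Constant::create(ov::element::i64, {channel_shape.size()}, channel_shape);
    const auto scaled = std::make_shared<Multiply>(
        restored, std::make_shared<Reshape>(scale, channel_dims, false));
    return std::make_shared<Add>(scaled, std::make_shared<Reshape>(bias, channel_dims, false));
}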
@@ -75,5 +75,12 @@ class ReduceShapeInfer : public IShapeInferSnippets {
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

class ReshapeShapeInfer : public IShapeInferSnippets {
ov::PartialShape target_shape;
public:
explicit ReshapeShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

} // namespace snippets
} // namespace ov
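
A minimal sketch of how these two members could be implemented; the .cpp is not in this diff, and the Result/status usage follows the surrounding IShapeInferSnippets convention, so the bodies are assumptions:

// Sketch only: inferred from the declaration above; includes omitted
namespace ov {
namespace snippets {

ReshapeShapeInfer::ReshapeShapeInfer(const std::shared_ptr<Node>& n) {
    const auto reshape = ov::as_type_ptr<op::Reshape>(n);
    OPENVINO_ASSERT(reshape, "Invalid node passed to ReshapeShapeInfer");
    target_shape = reshape->get_target_shape();
}

IShapeInferSnippets::Result ReshapeShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 1, "Invalid number of shapes passed to ReshapeShapeInfer");
    OPENVINO_ASSERT(target_shape.is_static(), "ReshapeShapeInfer assumes a static target shape");
    // The output shape is fully defined by the op's target shape
    return {{target_shape.get_shape()}, ShapeInferStatus::success};
}

}  // namespace snippets
}  // namespace ov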
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa.hpp
@@ -17,6 +17,7 @@
#include "op/fill.hpp"
#include "op/kernel.hpp"
#include "op/load.hpp"
#include "op/reshape.hpp"
#include "op/nop.hpp"
#include "op/scalar.hpp"
#include "op/powerstatic.hpp"
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa_tbl.hpp
@@ -16,6 +16,7 @@ OV_OP(LoopBegin, ov::snippets::op)
OV_OP(LoopEnd, ov::snippets::op)
OV_OP(Brgemm, ov::snippets::op)
OV_OP(BroadcastLoad, ov::snippets::op)
OV_OP(Reshape, ov::snippets::op)

OV_OP(Store, ov::snippets::op)

1 change: 1 addition & 0 deletions src/common/snippets/src/generator.cpp
@@ -79,6 +79,7 @@ RegType Generator::get_op_out_reg_type(const ov::Output<Node>& out) const {
std::dynamic_pointer_cast<op::IntermediateMemoryBuffer>(op) ||
std::dynamic_pointer_cast<op::NewMemoryBuffer>(op) ||
std::dynamic_pointer_cast<op::RankNormalization>(op) ||
std::dynamic_pointer_cast<op::Reshape>(op) ||
std::dynamic_pointer_cast<snippets::op::Store>(op)
#ifdef SNIPPETS_DEBUG_CAPS
|| std::dynamic_pointer_cast<op::PerfCountBeginBase>(op)
3 changes: 3 additions & 0 deletions src/common/snippets/src/lowered/linear_ir.cpp
@@ -298,9 +298,12 @@ VectorDims LinearIR::get_master_shape() const {
// Note: Snippets would benefit from a more generic master_shape calculation approach.
// It will be implemented in the scope of ROI propagation activity (ticket 120505)
const auto& source = out_exprs[0]->get_input_port_connector(0)->get_source();
auto last_expr = source.get_expr();
if (!m_config.m_enable_domain_optimization && out_exprs.size() == 1 &&
ov::is_type<snippets::op::Brgemm>(last_expr->get_node())) {
master_shape = utils::get_preordered_vdims(source);
} else if (out_exprs.size() == 1 && ov::is_type<snippets::op::Reshape>(last_expr->get_node())) {
master_shape = utils::get_preordered_vdims(last_expr->get_input_port_connector(0)->get_source());
} else {
for (const auto& oe : out_exprs) {
const auto& port_desc = oe->get_input_port_descriptor(0);
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
@@ -58,8 +58,8 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
// After Loop initialization, Buffer can be connected to LoopEnd - it's ok
continue;
} else {
OPENVINO_THROW(
"Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
// OPENVINO_THROW(
// "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
}
}
}
23 changes: 21 additions & 2 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
@@ -82,12 +82,23 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
const auto& consumer_inputs = out_connector->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
// TODO [96434]: Support RankNormalization (Reshape) in arbitrary place in pipeline, not just after inputs
if (ov::is_type<op::RankNormalization>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer");
if (ov::is_type<op::RankNormalization>(first_consumer->get_node()) ||
ov::is_type<op::Reshape>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization or Reshape is supposed to be the only consumer");
manually_assigned_gprs[first_consumer->get_output_port_connector(0)] = io_expr->get_index();
}
// A Reshape may directly follow RankNormalization after a Parameter; forward the Parameter's register through it as well
const auto& second_consumer = first_consumer->get_output_port_connector(0)->get_consumers().begin()->get_expr();
if (ov::is_type<op::RankNormalization>(second_consumer->get_node()) ||
ov::is_type<op::Reshape>(second_consumer->get_node())) {
manually_assigned_gprs[second_consumer->get_output_port_connector(0)] = io_expr->get_index();
}
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
// A Reshape right before a Result shares the Result's assigned register
const auto& parent = expr->get_input_port_connector(0)->get_source().get_expr();
if (ov::is_type<op::Reshape>(parent->get_node())) {
manually_assigned_gprs[parent->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
}
} else {
OPENVINO_THROW("Unsupported io_type detected");
}
@@ -97,6 +108,14 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
if (ov::is_type<op::IntermediateMemoryBuffer>(buffer)) {
manually_assigned_gprs[expr->get_input_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
// A Reshape right after an IntermediateMemoryBuffer shares the Buffer's register on both of its ports
const auto& first_consumer = expr->get_output_port_connector(0)->get_consumers().begin()->get_expr();
if (ov::is_type<op::Reshape>(first_consumer->get_node())) {
manually_assigned_gprs[first_consumer->get_input_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
manually_assigned_gprs[first_consumer->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
}
}
manually_assigned_gprs[expr->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
38 changes: 34 additions & 4 deletions src/common/snippets/src/lowered/pass/insert_buffers.cpp
@@ -153,16 +153,41 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
const auto& expr = entry_port->get_expr();
const auto port_idx = entry_port->get_index();
const auto node = expr->get_node();
const auto& parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();
auto parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();

auto first_not_reshape_parent_output = [&]() {
auto parent_expr = parent_expr_output.get_expr();
while (is_type<op::RankNormalization>(parent_expr->get_node()) ||
is_type<op::Reshape>(parent_expr->get_node())) {
parent_expr_output = parent_expr->get_input_port_connector(0)->get_source();
parent_expr = parent_expr_output.get_expr();
}
};
// This parent (the first one above any Reshape) is used to determine whether a Buffer is needed according to the loop info
first_not_reshape_parent_output();
const auto& parent_expr = parent_expr_output.get_expr();
const auto parent_port = parent_expr_output.get_index();
const auto parent = parent_expr->get_node();
const auto& parent_port = parent_expr_output.get_index();
const auto& parent = parent_expr->get_node();

if (ov::is_type<op::Buffer>(parent) ||
ov::is_type<op::VectorBuffer>(parent) ||
ov::is_type<ov::op::v0::Parameter>(parent) ||
ov::is_type<ov::op::v0::Constant>(parent))
continue;

// If the direct parent is a Reshape, the new Buffer must be connected in front of the Reshape
auto buffer_child = expr;
bool parent_is_reshape = false;
const auto& direct_parent = expr->get_input_port_connector(port_idx)->get_source().get_expr();
if (is_type<op::Reshape>(direct_parent->get_node())) {
buffer_child = direct_parent;
parent_is_reshape = true;
}

// Each MemoryAccess op needs Buffer
const auto parent_ma = ov::as_type_ptr<op::MemoryAccess>(parent);
const auto node_ma = ov::as_type_ptr<op::MemoryAccess>(node);
@@ -184,7 +209,12 @@
parent_expr_output,
m_buffer_allocation_rank);
const auto buffer = std::make_shared<op::IntermediateMemoryBuffer>(parent->output(parent_port), allocation_shape);
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port });
if (parent_is_reshape) {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos,
{ buffer_child->get_input_port(0) });
} else {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port });
}
}
}

41 changes: 29 additions & 12 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
@@ -31,12 +31,24 @@ size_t InsertLoadStore::get_count(const PortDescriptorPtr& port_desc) const {

bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
std::shared_ptr<Expression> data_expr = *data_expr_it;
auto consumer_inputs = data_expr->get_output_port_connector(0)->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
if (is_type<op::RankNormalization>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer");
data_expr = first_consumer;
}
const auto& consumer_inputs = data_expr->get_output_port_connector(0)->get_consumers();
// Walk down any RankNormalization/Reshape chain after the data expression; the Load is inserted below it
auto first_reshape_consumer = [&]() {
auto current_expr = data_expr;
auto first_consumer = consumer_inputs.begin()->get_expr();
while (is_type<op::RankNormalization>(first_consumer->get_node()) ||
is_type<op::Reshape>(first_consumer->get_node())) {
current_expr = first_consumer;
first_consumer = first_consumer->get_output_port_connector(0)->get_consumers().begin()->get_expr();
// OPENVINO_ASSERT(current_expr->get_output_port_connector(0)->get_consumers().size() == 1,
//                 "RankNormalization or Reshape is supposed to be the only consumer");
}
return current_expr;
};
data_expr = first_reshape_consumer();

const auto& data_ngraph_output = data_expr->get_node()->output(0);
bool was_inserted = false;
const auto& data_out = data_expr->get_output_port_connector(0);
@@ -56,12 +68,17 @@
}

bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
const auto& data_expr = *data_expr_it;
const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
const auto& parent_expr = parent_output.get_expr();
const auto port = parent_output.get_index();
const auto& parent = parent_expr->get_node();
const auto ma = ov::as_type_ptr<op::MemoryAccess>(parent);
auto data_expr = *data_expr_it;
auto parent_output = data_expr->get_input_port_connector(0)->get_source();
auto parent_expr = parent_output.get_expr();
// If a Reshape feeds the Result, the Store has to be inserted above the Reshape
if (is_type<op::Reshape>(parent_expr->get_node())) {
data_expr = parent_expr;
parent_output = data_expr->get_input_port_connector(0)->get_source();
parent_expr = parent_output.get_expr();
}
auto port = parent_output.get_index();
auto parent = parent_expr->get_node();
auto ma = ov::as_type_ptr<op::MemoryAccess>(parent);
if (ma && ma->is_memory_access_output_port(port))
return false;

6 changes: 4 additions & 2 deletions src/common/snippets/src/lowered/pass/mark_loops.cpp
@@ -27,7 +27,8 @@ bool MarkLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
return ov::is_type<ov::op::v0::Result>(node) ||
ov::is_type<ov::op::v0::Constant>(node) ||
ov::is_type<ov::op::v0::Parameter>(node) ||
ov::is_type<op::RankNormalization>(node);
ov::is_type<op::RankNormalization>(node) ||
ov::is_type<op::Reshape>(node);
};

auto are_conflicted = [](const ExpressionPort& lhs, const ExpressionPort& rhs) {
@@ -59,7 +60,8 @@ bool MarkLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
const auto& current_expr = *loop_end_pos;
const auto& current_node = current_expr->get_node();
if (ov::is_type<ov::op::v0::Result>(current_node) ||
ov::is_type<ov::op::v0::Constant>(current_node))
ov::is_type<ov::op::v0::Constant>(current_node) ||
ov::is_type<op::Reshape>(current_node))
break;

// We finish Loop if
21 changes: 16 additions & 5 deletions src/common/snippets/src/lowered/pass/propagate_layout.cpp
@@ -31,11 +31,22 @@ bool PropagateLayout::run(lowered::LinearIR& linear_ir, lowered::LinearIR::const
if (is_input) {
// Note that here we consider only the first child (which is usually load),
// but often there is another child - LoopEnd
auto consumer_inputs = target_connector->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
// If there is a RankNormalization op after a parameter - we should skip it
if (is_type<op::RankNormalization>(first_consumer->get_node()))
consumer_inputs = first_consumer->get_output_port_connector(0)->get_consumers();
// If there is a RankNormalization or Reshape op after a parameter - we should skip it
auto first_not_reshape_child = [&]() {
auto current_expr = expr;
auto first_child = target_connector->get_consumers().begin()->get_expr();
while (is_type<op::RankNormalization>(first_child->get_node()) ||
is_type<op::Reshape>(first_child->get_node())) {
current_expr = first_child;
first_child = first_child->get_output_port_connector(0)->get_consumers().begin()->get_expr();
}
return current_expr;
};
const auto& consumer_inputs = first_not_reshape_child()->get_output_port_connector(0)->get_consumers();

std::set<std::vector<size_t>> child_layouts;
for (const auto& child_input : consumer_inputs) {
const auto& child = child_input.get_expr();