Skip to content

Commit

Permalink
groupNorm tokenization, decomposition and scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Mar 18, 2024
1 parent c208a88 commit 1dbc0fd
Show file tree
Hide file tree
Showing 27 changed files with 505 additions and 31 deletions.
38 changes: 38 additions & 0 deletions src/common/snippets/include/snippets/op/reshape.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_inference.hpp"

namespace ov {
namespace snippets {
namespace op {

/**
* @interface Reshape
* @brief Reshape input tensor to required target shape
* @ingroup snippets
*/
class Reshape : public ov::op::Op {
public:
OPENVINO_OP("Reshape", "SnippetsOpset");
// Builds a Reshape over `x` whose output is reported with `target_shape`.
Reshape(const Output<Node>& x, ov::PartialShape target_shape);
Reshape() = default;

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;

// Accessors for the shape the output is reported with (see validate_and_infer_types).
const ov::PartialShape& get_target_shape() const;
void set_target_shape(ov::PartialShape shape);

private:
// Shape set on output 0; defaults to an empty PartialShape until configured.
ov::PartialShape m_target_shape = {};
};

} // namespace op
} // namespace snippets
} // namespace ov
27 changes: 27 additions & 0 deletions src/common/snippets/include/snippets/pass/gn_decomposition.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface GNDecomposition
* @brief Decomposes GroupNormalization to a range of low-level operations
* @ingroup snippets
*/
class GNDecomposition: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GNDecomposition", "0");
// Registers the GroupNormalization matcher and its decomposition callback.
GNDecomposition();
};

} // namespace pass
} // namespace snippets
} // namespace ov
28 changes: 28 additions & 0 deletions src/common/snippets/include/snippets/pass/gn_tokenization.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface TokenizeGNSnippets
* @brief Tokenize GroupNormalization to a subgraph
* @ingroup snippets
*/
class TokenizeGNSnippets: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TokenizeGNSnippets", "0");
// Registers the matcher that wraps GroupNormalization into a snippets Subgraph.
TokenizeGNSnippets();
};

} // namespace pass
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,12 @@ class ReduceShapeInfer : public IShapeInferSnippets {
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

// Shape inference for snippets::op::Reshape: the result shape is the op's
// stored target shape rather than a function of the input dims.
class ReshapeShapeInfer : public IShapeInferSnippets {
// Target shape captured from the Reshape node at construction time.
ov::PartialShape target_shape;
public:
explicit ReshapeShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

} // namespace snippets
} // namespace ov
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "op/fill.hpp"
#include "op/kernel.hpp"
#include "op/load.hpp"
#include "op/reshape.hpp"
#include "op/nop.hpp"
#include "op/scalar.hpp"
#include "op/powerstatic.hpp"
Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ OV_OP(LoopBegin, ov::snippets::op)
OV_OP(LoopEnd, ov::snippets::op)
OV_OP(Brgemm, ov::snippets::op)
OV_OP(BroadcastLoad, ov::snippets::op)
OV_OP(Reshape, ov::snippets::op)

OV_OP(Store, ov::snippets::op)

Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ RegType Generator::get_op_out_reg_type(const ov::Output<Node>& out) const {
std::dynamic_pointer_cast<op::IntermediateMemoryBuffer>(op) ||
std::dynamic_pointer_cast<op::NewMemoryBuffer>(op) ||
std::dynamic_pointer_cast<op::RankNormalization>(op) ||
std::dynamic_pointer_cast<op::Reshape>(op) ||
std::dynamic_pointer_cast<snippets::op::Store>(op)
#ifdef SNIPPETS_DEBUG_CAPS
|| std::dynamic_pointer_cast<op::PerfCountBeginBase>(op)
Expand Down
3 changes: 3 additions & 0 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,12 @@ VectorDims LinearIR::get_master_shape() const {
// Note: Snippets would benefit from a more generic master_shape calculation approach.
// It will be implemented in the scope of ROI propagation activity (ticket 120505)
const auto& source = out_exprs[0]->get_input_port_connector(0)->get_source();
auto last_exp = source.get_expr();
if (!m_config.m_enable_domain_optimization && out_exprs.size() == 1 &&
ov::is_type<snippets::op::Brgemm>(source.get_expr()->get_node())) {
master_shape = utils::get_preordered_vdims(source);
} else if (out_exprs.size() == 1 && ov::is_type<snippets::op::Reshape>(last_exp->get_node())) {
master_shape = utils::get_preordered_vdims(last_exp->get_input_port_connector(0)->get_source());
} else {
for (const auto& oe : out_exprs) {
const auto& port_desc = oe->get_input_port_descriptor(0);
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
// After Loop initialization, Buffer can be connected to LoopEnd - it's ok
continue;
} else {
OPENVINO_THROW(
"Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
// OPENVINO_THROW(
// "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
}
}
}
Expand Down
29 changes: 23 additions & 6 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,24 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
if (io_expr->get_type() == IOExpression::io_type::INPUT) {
const auto& out_connector = expr->get_output_port_connector(0);
manually_assigned_gprs[out_connector] = io_expr->get_index();
const auto& consumer_inputs = out_connector->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
// TODO [96434]: Support RankNormalization (Reshape) in arbitrary place in pipeline, not just after inputs
if (ov::is_type<op::RankNormalization>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer");
manually_assigned_gprs[first_consumer->get_output_port_connector(0)] = io_expr->get_index();
// TODO [96434]: Support RankNormalization/Reshape in arbitrary place in pipeline, not just after inputs
// reshape rankNormalization sequence
auto consumer_inputs = out_connector->get_consumers();
auto child_exp = consumer_inputs.begin()->get_expr();
while (ov::is_type<op::RankNormalization>(child_exp->get_node()) ||
ov::is_type<op::Reshape>(child_exp->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization or Reshape is supposed to be the only consumer");
manually_assigned_gprs[child_exp->get_output_port_connector(0)] = io_expr->get_index();
consumer_inputs = child_exp->get_output_port_connector(0)->get_consumers();
child_exp = consumer_inputs.begin()->get_expr();
}
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
// reshape before result
const auto &parent = expr->get_input_port_connector(0)->get_source().get_expr();
if (ov::is_type<op::Reshape>(parent->get_node())) {
manually_assigned_gprs[parent->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
}
} else {
OPENVINO_THROW("Unsupported io_type detected");
}
Expand All @@ -97,6 +106,14 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
if (ov::is_type<op::IntermediateMemoryBuffer>(buffer)) {
manually_assigned_gprs[expr->get_input_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
// reshape in the middle of subgraph. IntermediateMemoryBuffer is inserted before reshape as new loop should start.
const auto& first_consumer = expr->get_output_port_connector(0)->get_consumers().begin()->get_expr();
if (ov::is_type<op::Reshape>(first_consumer->get_node())) {
manually_assigned_gprs[first_consumer->get_input_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
manually_assigned_gprs[first_consumer->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
}
}
manually_assigned_gprs[expr->get_output_port_connector(0)] =
static_cast<Reg>(num_results + num_parameters + buffer_id);
Expand Down
32 changes: 28 additions & 4 deletions src/common/snippets/src/lowered/pass/insert_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,35 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
const auto& expr = entry_port->get_expr();
const auto port_idx = entry_port->get_index();
const auto node = expr->get_node();
const auto& parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();
auto parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();

auto first_not_reshape_parent_output = [&]() {
auto parent_expr = parent_expr_output.get_expr();
while (is_type<op::Reshape>(parent_expr->get_node())) {
parent_expr_output = parent_expr->get_input_port_connector(0)->get_source();
parent_expr = parent_expr_output.get_expr();
}
};
// this parent(before reshape) is used to determine if buffer needed according loopInfo
first_not_reshape_parent_output();
const auto& parent_expr = parent_expr_output.get_expr();
const auto parent_port = parent_expr_output.get_index();
const auto parent = parent_expr->get_node();
const auto& parent_port = parent_expr_output.get_index();
const auto& parent = parent_expr->get_node();
if (ov::is_type<op::Buffer>(parent) ||
ov::is_type<op::VectorBuffer>(parent) ||
ov::is_type<ov::op::v0::Parameter>(parent) ||
ov::is_type<ov::op::v0::Constant>(parent))
continue;

// insert buffer before reshape
auto buffer_child = expr;
bool parent_is_reshape = false;
auto p_exp = expr->get_input_port_connector(port_idx)->get_source().get_expr();
if (is_type<op::Reshape>(p_exp->get_node())) {
buffer_child = p_exp;
parent_is_reshape = true;
}

// Each MemoryAccess op needs Buffer
const auto parent_ma = ov::as_type_ptr<op::MemoryAccess>(parent);
const auto node_ma = ov::as_type_ptr<op::MemoryAccess>(node);
Expand All @@ -178,7 +197,12 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
parent_expr_output,
m_buffer_allocation_rank);
const auto buffer = std::make_shared<op::IntermediateMemoryBuffer>(parent->output(parent_port), allocation_shape);
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port });
if (parent_is_reshape) {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos,
{ buffer_child->get_input_port(0) });
} else {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port });
}
}
}

Expand Down
41 changes: 29 additions & 12 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,24 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const {

bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
std::shared_ptr<Expression> data_expr = *data_expr_it;
auto consumer_inputs = data_expr->get_output_port_connector(0)->get_consumers();
const auto& first_consumer = consumer_inputs.begin()->get_expr();
if (is_type<op::RankNormalization>(first_consumer->get_node())) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer");
data_expr = first_consumer;
}
const auto& consumer_inputs = data_expr->get_output_port_connector(0)->get_consumers();
auto first_reshape_consumer = [&]() {
auto current_exp = data_expr;
auto first_consumer = consumer_inputs.begin()->get_expr();
while (1) {
if (is_type<op::RankNormalization>(first_consumer->get_node()) ||
is_type<op::Reshape>(first_consumer->get_node())) {
current_exp = first_consumer;
first_consumer = first_consumer->get_output_port_connector(0)->get_consumers().begin()->get_expr();
// OPENVINO_ASSERT(current_exp->get_output_port_connector(0)->get_consumers().size() == 1,
// "RankNormalization or Reshape is supposed to be the only consumer");
} else {
return current_exp;
}
}
};
data_expr = first_reshape_consumer();

const auto& data_ngraph_output = data_expr->get_node()->output(0);
bool was_inserted = false;
const auto& data_out = data_expr->get_output_port_connector(0);
Expand All @@ -61,12 +73,17 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
}

bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
const auto& data_expr = *data_expr_it;
const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
const auto& parent_expr = parent_output.get_expr();
const auto port = parent_output.get_index();
const auto& parent = parent_expr->get_node();
const auto ma = ov::as_type_ptr<op::MemoryAccess>(parent);
auto data_expr = *data_expr_it;
auto parent_output = data_expr->get_input_port_connector(0)->get_source();
auto parent_expr = parent_output.get_expr();
if (is_type<op::Reshape>(parent_expr->get_node())) {
data_expr = parent_expr;
parent_output = data_expr->get_input_port_connector(0)->get_source();
parent_expr = parent_output.get_expr();
}
auto port = parent_output.get_index();
auto parent = parent_expr->get_node();
auto ma = ov::as_type_ptr<op::MemoryAccess>(parent);
if (ma && ma->is_memory_access_output_port(port))
return false;

Expand Down
3 changes: 2 additions & 1 deletion src/common/snippets/src/lowered/pass/mark_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ bool MarkLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
return ov::is_type<ov::op::v0::Result>(node) ||
ov::is_type<ov::op::v0::Constant>(node) ||
ov::is_type<ov::op::v0::Parameter>(node) ||
ov::is_type<op::RankNormalization>(node);
ov::is_type<op::RankNormalization>(node) ||
ov::is_type<op::Reshape>(node);
};

auto are_conflicted = [](const ExpressionPort& lhs, const ExpressionPort& rhs) {
Expand Down
43 changes: 43 additions & 0 deletions src/common/snippets/src/op/reshape.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <utility>

#include "snippets/itt.hpp"

#include "snippets/op/reshape.hpp"
#include "snippets/utils.hpp"


namespace ov {
namespace snippets {
namespace op {
// Constructs a Reshape over `arg`; `target_shape` becomes the shape reported
// on output 0 (see validate_and_infer_types).
Reshape::Reshape(const Output<Node>& arg, ov::PartialShape target_shape)
    // Taking the by-value parameter with std::move avoids copying the shape
    // (clang-tidy performance-unnecessary-value-param).
    : Op({arg}), m_target_shape(std::move(target_shape)) {
    constructor_validate_and_infer_types();
}

void Reshape::validate_and_infer_types() {
// Output keeps the input element type; its shape is the stored target shape.
set_output_type(0, get_input_element_type(0), m_target_shape);
}

std::shared_ptr<Node> Reshape::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(Reshape);
check_new_args_count(this, new_args);
return std::make_shared<Reshape>(new_args.at(0), get_target_shape());
}

bool Reshape::visit_attributes(AttributeVisitor& visitor) {
// Expose the target shape for serialization/visualization.
visitor.on_attribute("target_shape", m_target_shape);
return true;
}

// Returns the shape reported on output 0 (by reference; valid while the op lives).
const ov::PartialShape& Reshape::get_target_shape() const {
return m_target_shape;
}

// Replaces the stored target shape. Note: does NOT re-run shape inference;
// callers that need the output type refreshed must call validate_and_infer_types().
void Reshape::set_target_shape(ov::PartialShape shape) {
    // Move from the by-value parameter instead of copy-assigning
    // (clang-tidy performance-unnecessary-value-param).
    m_target_shape = std::move(shape);
}
}// namespace op
}// namespace snippets
}// namespace ov
Loading

0 comments on commit 1dbc0fd

Please sign in to comment.