
Commit

Apply Alexandra comments 2
chenhu-wang committed Mar 19, 2024
1 parent c810ef1 commit 3305c93
Showing 12 changed files with 103 additions and 54 deletions.
8 changes: 8 additions & 0 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
@@ -231,6 +231,14 @@ class LinearIR {
*/
static std::vector<ExpressionPtr> propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream);

/**
* @brief Get the last shape infer expression in a sequence starting from start_expr. If no shape infer op is connected to start_expr, return start_expr itself.
* @param start_expr The expression to start the search from.
* @param downstream Search downstream if true, otherwise search upstream.
* @return The last shape infer expression.
*/
static ExpressionPtr get_last_shape_infer_expr(const ExpressionPtr& start_expr, bool downstream);

private:
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;

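For reference, a minimal sketch of how this expression-level helper is intended to be used, mirroring the insert_load_store.cpp hunk later in this commit; data_expr_it and the loop body are illustrative, not part of the commit:

// Sketch only: walk past shape-infer expressions (Reshape, RankNormalization)
// downstream of a data expression before attaching memory-access ops to it.
auto data_expr = *data_expr_it;                                    // e.g. a Parameter expression
data_expr = LinearIR::get_last_shape_infer_expr(data_expr, true);  // true = search downstream
const auto& data_out = data_expr->get_output_port_connector(0);    // real consumers hang off this port
for (const auto& consumer_input : data_out->get_consumers()) {
    // insert Load / propagate offsets / validate layouts per consumer ...
}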
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/op/subgraph.hpp
@@ -140,6 +140,7 @@ class Subgraph : public ov::op::util::SubGraphOp {
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool;
static auto is_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> bool;
static auto get_last_shape_infer_op(const std::shared_ptr<ov::Node>& op, bool downstream) -> std::shared_ptr<ov::Node>;

void data_flow_transformations(const BlockedShapeVector& blocked_input_shapes = {},
const std::vector<ov::element::Type>& input_precisions = {},
7 changes: 6 additions & 1 deletion src/common/snippets/include/snippets/utils.hpp
@@ -154,8 +154,13 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
* @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order.
*/
VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port);
/**
* @brief Returns element count of a shape
* @param shape input shape
* @return element count of input shape
*/
inline auto get_shape_size(const VectorDims& shape) -> size_t {
return std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<size_t>());
return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}
/* --------------------------- */

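The one-line change in get_shape_size above replaces the int64_t initial value with size_t. A small self-contained illustration of why the init type matters (std::accumulate deduces its accumulator type from the init argument); VectorDims is assumed here to be std::vector<size_t>:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

using VectorDims = std::vector<size_t>;  // assumption for this sketch

inline size_t shape_size_sketch(const VectorDims& shape) {
    // The accumulator type comes from the third argument, so it must be size_t
    // to stay consistent with std::multiplies<size_t> and the element type of
    // VectorDims; an int64_t init silently mixes signed and unsigned arithmetic.
    return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}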
48 changes: 40 additions & 8 deletions src/common/snippets/src/lowered/linear_ir.cpp
@@ -366,13 +366,14 @@ VectorDims LinearIR::get_master_shape() const {
}
// Note: Snippets would benefit from a more generic master_shape calculation approach.
// It will be implemented in the scope of ROI propagation activity (ticket 120505)
const auto& source = out_exprs[0]->get_input_port_connector(0)->get_source();
auto last_exp = source.get_expr();
if (!m_config.m_enable_domain_optimization && out_exprs.size() == 1 &&
ov::is_type<snippets::op::Brgemm>(source.get_expr()->get_node())) {
master_shape = utils::get_preordered_vdims(source);
} else if (out_exprs.size() == 1 && ov::is_type<snippets::op::Reshape>(last_exp->get_node())) {
master_shape = utils::get_preordered_vdims(last_exp->get_input_port_connector(0)->get_source());
if (out_exprs.size() == 1) {
const auto& source = out_exprs[0]->get_input_port_connector(0)->get_source();
if (!m_config.m_enable_domain_optimization && ov::is_type<snippets::op::Brgemm>(source.get_expr()->get_node())) {
master_shape = utils::get_preordered_vdims(source);
} else {
auto last_shape_infer_expr = LinearIR::get_last_shape_infer_expr(out_exprs[0], false);
master_shape = utils::get_preordered_vdims(last_shape_infer_expr->get_input_port_connector(0)->get_source());
}
} else {
for (const auto& oe : out_exprs) {
const auto& port_desc = oe->get_input_port_descriptor(0);
@@ -514,7 +515,7 @@ std::vector<ExpressionPtr> LinearIR::propagate_expr_through_shape_infer_ops(cons
current_exp = first_child;
if (current_exp->get_output_count() == 0)
break;
auto consumers = current_exp->get_output_port_connector(0)->get_consumers();
consumers = current_exp->get_output_port_connector(0)->get_consumers();
first_child = consumers.begin()->get_expr();
}
return shape_infer_exprs;
@@ -534,6 +535,37 @@ std::vector<ExpressionPtr> LinearIR::propagate_expr_through_shape_infer_ops(cons
}
}

ExpressionPtr LinearIR::get_last_shape_infer_expr(const ExpressionPtr& start_expr, bool downstream) {
auto last_exp = start_expr;
if (downstream) {
if (last_exp->get_output_count() == 0)
return last_exp;
auto consumers = last_exp->get_output_port_connector(0)->get_consumers();
auto first_child = consumers.begin()->get_expr();
while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
last_exp = first_child;
if (last_exp->get_output_count() == 0)
break;
consumers = last_exp->get_output_port_connector(0)->get_consumers();
first_child = consumers.begin()->get_expr();
}
return last_exp;
} else {
// upstream
if (last_exp->get_input_count() == 0)
return last_exp;
auto first_source = last_exp->get_input_port_connector(0)->get_source().get_expr();
while (op::Subgraph::is_shape_infer_op(first_source->get_node())) {
last_exp = first_source;
if (last_exp->get_input_count() == 0)
break;
first_source = last_exp->get_input_port_connector(0)->get_source().get_expr();
}
return last_exp;
}
}

LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs)
: ShapeInferSnippetsNode(),
m_exprs{std::make_shared<container>(body_exprs)} {
8 changes: 4 additions & 4 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
@@ -46,21 +46,21 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
}
}
// Propagate to down: in Load. Buffer can have several Load
const auto& buffer_out = buffer_expr->get_output_port_connector(0);
auto last_shape_infer = ov::snippets::lowered::LinearIR::get_last_shape_infer_expr(buffer_expr, true);
const auto& buffer_out = last_shape_infer->get_output_port_connector(0);
for (const auto& child_expr_input : buffer_out->get_consumers()) {
const auto& child_expr = child_expr_input.get_expr();
const auto port = child_expr_input.get_index();
const auto& child_node = child_expr->get_node();
auto memory_access = ov::as_type_ptr<ov::snippets::op::MemoryAccess>(child_node);
if (memory_access && memory_access->is_memory_access_input_port(port)) {
memory_access->set_input_offset(offset, port);
} else if (ov::is_type<op::LoopEnd>(child_node) || op::Subgraph::is_shape_infer_op(child_node)) {
} else if (ov::is_type<op::LoopEnd>(child_node)) {
// After Loop initialization, Buffer can be connected to LoopEnd - it's ok
// There are also buffer before shape-changing ops
continue;
} else {
OPENVINO_THROW(
"Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
"Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation");
}
}
}
8 changes: 2 additions & 6 deletions src/common/snippets/src/lowered/pass/insert_buffers.cpp
@@ -190,12 +190,8 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
parent_expr_output,
m_buffer_allocation_rank);
const auto buffer = std::make_shared<op::IntermediateMemoryBuffer>(parent->output(parent_port), allocation_shape);
if (has_shape_infer_parent) {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos,
{ top_shape_infer_expr->get_input_port(0) });
} else {
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port });
}
const auto buffer_consumer = has_shape_infer_parent ? top_shape_infer_expr->get_input_port(0) : *entry_port;
linear_ir.insert_node(buffer, std::vector<ExpressionPort>{ parent_expr_output }, buffer_loop_ids, false, pos, { buffer_consumer });
}
}

9 changes: 2 additions & 7 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
@@ -36,10 +36,7 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const {

bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
std::shared_ptr<Expression> data_expr = *data_expr_it;
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, true);
if (!shape_infer_consumers.empty())
data_expr = shape_infer_consumers.back();

data_expr = LinearIR::get_last_shape_infer_expr(data_expr, true);
const auto& data_ngraph_output = data_expr->get_node()->output(0);
bool was_inserted = false;
const auto& data_out = data_expr->get_output_port_connector(0);
@@ -60,9 +57,7 @@ bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExpr

bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
auto data_expr = *data_expr_it;
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, false);
if (!shape_infer_consumers.empty())
data_expr = shape_infer_consumers.back();
data_expr = LinearIR::get_last_shape_infer_expr(data_expr, false);

const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
const auto& parent_expr = parent_output.get_expr();
11 changes: 3 additions & 8 deletions src/common/snippets/src/lowered/pass/validate.cpp
@@ -32,8 +32,7 @@ void validate_ports(const ExpressionPtr& expr) {
void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Parameter>(expr->get_node()),
"Parameter validation expects Parameter op");
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true);
auto expr_val = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back();
auto expr_val = LinearIR::get_last_shape_infer_expr(expr, true);
auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers();
std::set<std::vector<size_t>> layouts;
for (const auto& consumer_input : consumer_inputs) {
@@ -52,8 +51,7 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) {
OPENVINO_ASSERT(ov::is_type<ov::op::v0::Result>(expr->get_node()),
"Result validation expects Result op");
auto shape_infer_parents = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, false);
auto expr_val = shape_infer_parents.empty() ? expr : shape_infer_parents.back();
auto expr_val = LinearIR::get_last_shape_infer_expr(expr, false);
const auto source = expr_val->get_input_port_connector(0)->get_source();
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()),
@@ -68,10 +66,7 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) {
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(source.get_expr()->get_node());
OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()),
"Buffer expects MemoryAccess parent");

auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true);
auto expr_val = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back();

auto expr_val = LinearIR::get_last_shape_infer_expr(expr, true);
const auto& out = expr_val->get_output_port_connector(0);
const auto consumers = out->get_consumers();
for (const auto& consumer_input : consumers) {
31 changes: 29 additions & 2 deletions src/common/snippets/src/op/subgraph.cpp
@@ -87,6 +87,34 @@ auto Subgraph::is_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> bool {
ov::is_type<snippets::op::RankNormalization>(op);
}

auto Subgraph::get_last_shape_infer_op(const std::shared_ptr<ov::Node>& op, bool downstream) -> std::shared_ptr<ov::Node> {
auto last_op = op;
if (downstream) {
if (last_op->get_output_size() == 0)
return last_op;
auto first_child = last_op->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
while (op::Subgraph::is_shape_infer_op(first_child)) {
last_op = first_child;
if (last_op->get_output_size() == 0)
break;
first_child = last_op->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
return last_op;
} else {
// upstream
if (last_op->get_input_size() == 0)
return last_op;
auto first_parent = last_op->get_input_node_shared_ptr(0);
while (op::Subgraph::is_shape_infer_op(first_parent)) {
last_op = first_parent;
if (last_op->get_input_size() == 0)
break;
first_parent = last_op->get_input_node_shared_ptr(0);
}
return last_op;
}
}

void Subgraph::init_config() {
auto update = [](bool& flag, bool status) { flag = flag || status; };
const auto ops = body_ptr()->get_ops();
@@ -327,8 +355,7 @@ VectorDims Subgraph::infer_master_shape() {
OPENVINO_ASSERT(!output_dims.empty(), "Can't calculate master_shape before the first shape inference");
} else {
for (const auto& res : body_ptr()->get_results()) {
auto reshape = ov::as_type_ptr<op::Reshape>(res->get_input_node_shared_ptr(0));
auto res_input = reshape ? reshape->input(0) : res->input(0);
auto res_input = get_last_shape_infer_op(res, false)->input(0);
OPENVINO_ASSERT(res_input.get_partial_shape().is_static(), "Result have dynamic shape in static pipeline");
// We need to account to the shape's layout stored in Output<Node> rt_info
const auto& planar_shape = utils::get_preordered_pshape(res_input.get_source_output());
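A hedged sketch of the node-level counterpart in use, following the infer_master_shape() hunk above; res stands for some Result node of the subgraph body and is illustrative only:

// Sketch only: get_last_shape_infer_op(res, false) steps upstream over
// shape-infer ops (Reshape, RankNormalization), so input(0) of the returned
// node is fed by the op that actually produces the data, and its planar
// shape can be read directly.
auto res_input = Subgraph::get_last_shape_infer_op(res, false)->input(0);
const auto& planar_shape = utils::get_preordered_pshape(res_input.get_source_output());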
15 changes: 4 additions & 11 deletions src/common/snippets/src/pass/align_element_types.cpp
@@ -29,7 +29,7 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
for (size_t i = 0; i < m_output_precisions.size(); i++) {
const auto needed_out_type = m_output_precisions[i];
if (results[i]->get_input_element_type(0) != needed_out_type) {
std::shared_ptr<ov::Node> consumer = results[i];
std::shared_ptr<ov::Node> consumer = op::Subgraph::get_last_shape_infer_op(results[i], false);
auto parent_output = consumer->get_input_source_output(0);

// Snippets supports Transpose only after Parameter or before Result nodes
@@ -76,18 +76,11 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
parameter->set_element_type(needed_in_type);
parameter->validate_and_infer_types();

auto parent_output = parameter->output(0);
auto consumer_inputs = parent_output.get_target_inputs();

auto first_child = consumer_inputs.begin()->get_node()->shared_from_this();
// Note: shape infer ops is designed for shape-inference purposes only.
// It does not process any data (nor does it emit any code), so it doesn't require Convert operations
while (op::Subgraph::is_shape_infer_op(first_child)) {
OPENVINO_ASSERT(consumer_inputs.size() == 1, "Shape infer ops are supposed to be the only consumer");
parent_output = first_child->output(0);
consumer_inputs = parent_output.get_target_inputs();
first_child = consumer_inputs.begin()->get_node()->shared_from_this();
}
auto first_child = op::Subgraph::get_last_shape_infer_op(parameter, true);
auto parent_output = first_child->output(0);
auto consumer_inputs = parent_output.get_target_inputs();

// Snippets supports Transpose only after Parameter or before Result nodes
// So we have to insert Convert after Transpose (if there is) on Subgraph inputs
@@ -201,8 +201,7 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g
switch (expr->get_type()) {
case snippets::lowered::IOExpression::io_type::INPUT: {
// input->shape changing ops->load
auto shape_infer_consumers = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, true);
auto mem_desc_expr = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back();
auto mem_desc_expr = ov::snippets::lowered::LinearIR::get_last_shape_infer_expr(expr, true);
auto consumer_inputs = mem_desc_expr->get_output_port_connector(0)->get_consumers();
for (const auto& child_input : consumer_inputs) {
const auto ma = ov::as_type_ptr<snippets::op::MemoryAccess>(child_input.get_expr()->get_node());
@@ -217,8 +216,7 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g
}
case snippets::lowered::IOExpression::io_type::OUTPUT: {
// store->shape changing ops->result
auto shape_infer_sources = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, false);
auto mem_desc_expr = shape_infer_sources.empty() ? expr : shape_infer_sources.back();
auto mem_desc_expr = ov::snippets::lowered::LinearIR::get_last_shape_infer_expr(expr, false);
desc = mem_desc_expr->get_input_port_connector(0)->get_source().get_descriptor_ptr();
etype = mem_desc_expr->get_node()->get_input_element_type(0);
break;
@@ -471,10 +471,9 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
},
ov::pass::NormalizeL2Decomposition);

// todo: only support f32 in first version
CPU_SET_CALLBACK_X64(manager,
[](const_node_ptr &node) -> bool {
return !node->is_dynamic() && node->get_element_type() == element::f32;
[this](const_node_ptr &node) -> bool {
return !node->is_dynamic() && node->get_element_type() == element::f32 && inferencePrecision != ov::element::bf16;
},
ov::pass::GroupNormalizationDecomposition);

