Skip to content

Commit

Permalink
gn tokenization condition update, manual subtensor settings remove
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Mar 25, 2024
1 parent 5e95ef6 commit 105387f
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 33 deletions.
8 changes: 4 additions & 4 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ class LinearIR {
exprIt replace_with_expr(const std::vector<ExpressionPtr>& old_exprs, const ExpressionPtr& new_expr);

/**
* @brief Propagate start_expr through zero to several consecutive shape infer exprs(such as reshape, rankNormalization).
* @param start_expr Propagate from start_expr.
* @param downstream Propagate downstream if it's true, otherwise propagate upstream.
* @brief Get zero to several consecutive shape infer exprs(such as reshape, rankNormalization) from start_expr.
* @param start_expr Collect from start_expr.
* @param downstream Collect downstream if it's true, otherwise collect upstream.
* @return shape infer op consumers as a sequence if downstream, or shape infer op sources as a sequence if upstream.
*/
static std::vector<ExpressionPtr> propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream);
static std::vector<ExpressionPtr> get_shape_infer_expr_seq(const ExpressionPtr& start_expr, bool downstream);

/**
 * @brief Get the last shape infer op from start_expr in a sequence. If no shape infer op is connected to start_expr, return start_expr.
Expand Down
1 change: 0 additions & 1 deletion src/common/snippets/include/snippets/op/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#pragma once

#include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_inference.hpp"

namespace ov {
namespace snippets {
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
return replace_with_expr(old_exprs, new_expr, insertion_place);
}

std::vector<ExpressionPtr> LinearIR::propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream) {
std::vector<ExpressionPtr> LinearIR::get_shape_infer_expr_seq(const ExpressionPtr& start_expr, bool downstream) {
std::vector<ExpressionPtr> shape_infer_exprs;
auto current_exp = start_expr;
if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
Expand Down
6 changes: 3 additions & 3 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
manually_assigned_gprs[out_connector] = io_expr->get_index();
// TODO [96434]: Support shape infer ops in arbitrary place in pipeline, not just after inputs
// shape infer ops sequence after input
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(io_expr, true);
auto shape_infer_consumers = LinearIR::get_shape_infer_expr_seq(io_expr, true);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
Expand All @@ -91,7 +91,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
} else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
// shape infer ops sequence before result
auto shape_infer_sources = LinearIR::propagate_expr_through_shape_infer_ops(io_expr, false);
auto shape_infer_sources = LinearIR::get_shape_infer_expr_seq(io_expr, false);
if (!shape_infer_sources.empty()) {
for (const auto& parent_shape_infer_expr : shape_infer_sources) {
manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
Expand All @@ -108,7 +108,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
static_cast<Reg>(num_results + num_parameters + buffer_id);
// shape infer ops in the middle of the subgraph. An IntermediateMemoryBuffer is inserted before the reshape, as a new loop should start there.
// child shape infer ops share the same memory as the IntermediateMemoryBuffer.
auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true);
auto shape_infer_consumers = LinearIR::get_shape_infer_expr_seq(expr, true);
if (!shape_infer_consumers.empty()) {
for (const auto& child_shape_infer_expr : shape_infer_consumers) {
manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/insert_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
bool has_shape_infer_parent = false;
auto top_shape_infer_expr = expr;
// the parent before the shape infer ops is used to determine whether a buffer is needed, according to the loopInfo
auto shape_infer_parents = LinearIR::propagate_expr_through_shape_infer_ops(first_parent_expr, false);
auto shape_infer_parents = LinearIR::get_shape_infer_expr_seq(first_parent_expr, false);
if (!shape_infer_parents.empty()) {
parent_expr_output = shape_infer_parents.back()->get_input_port_connector(0)->get_source();
has_shape_infer_parent = true;
Expand Down
19 changes: 1 addition & 18 deletions src/common/snippets/src/pass/gn_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ GNDecomposition::GNDecomposition() {
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::pass::GNDecomposition")
auto group_norm_node = ov::as_type_ptr<ov::op::v12::GroupNormalization>(m.get_match_root());
OPENVINO_ASSERT(!group_norm_node->is_dynamic(), "GroupNormalization decomposition in snippets only support static node.");

const auto data = group_norm_node->input_value(0);
const auto scale = group_norm_node->input_value(1);
Expand Down Expand Up @@ -87,26 +88,8 @@ GNDecomposition::GNDecomposition() {
auto eps_add = std::make_shared<ov::op::v1::Add>(sqr_mean, eps_node); // fma to this add and parent multiply
// variance = sqrt( reducemean( (x - mean) ^ 2 ) + eps )
auto variance = std::make_shared<ov::op::v0::Sqrt>(eps_add);

// divide variance
const auto variance_inv = std::make_shared<ov::snippets::op::PowerStatic>(variance, -1.f);

// remove invariance in inner loop
std::vector<size_t> subtensor_invariance(group_rank, 1);
subtensor_invariance[3] = PortDescriptor::ServiceDimensions::FULL_DIM;
PortDescriptorUtils::set_port_descriptor_ptr(reduce_mean->input(0), std::make_shared<PortDescriptor>(reduce_mean->input(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(reduce_mean->output(0), std::make_shared<PortDescriptor>(reduce_mean->output(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(sqr_mean->input(0), std::make_shared<PortDescriptor>(sqr_mean->input(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(sqr_mean->input(1), std::make_shared<PortDescriptor>(sqr_mean->input(1), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(sqr_mean->output(0), std::make_shared<PortDescriptor>(sqr_mean->output(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(eps_add->input(0), std::make_shared<PortDescriptor>(eps_add->input(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(eps_add->input(1), std::make_shared<PortDescriptor>(eps_add->input(1), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(eps_add->output(0), std::make_shared<PortDescriptor>(eps_add->output(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(variance->input(0), std::make_shared<PortDescriptor>(variance->input(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(variance->output(0), std::make_shared<PortDescriptor>(variance->output(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(variance_inv->input(0), std::make_shared<PortDescriptor>(variance_inv->input(0), subtensor_invariance));
PortDescriptorUtils::set_port_descriptor_ptr(variance_inv->output(0), std::make_shared<PortDescriptor>(variance_inv->output(0), subtensor_invariance));

auto mvn = std::make_shared<ov::op::v1::Multiply>(sub_mean, variance_inv);

// reshape mvn from [N, group, 1, (C / group) * spatial] to [N, group, C / group, spatial]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,35 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
},
ov::pass::NormalizeL2Decomposition);

CPU_SET_CALLBACK_X64(manager,
[this](const_node_ptr &node) -> bool {
return !node->is_dynamic() && node->get_element_type() == element::f32 && inferencePrecision != ov::element::bf16;
},
ov::pass::GroupNormalizationDecomposition);
if (!useLpt) {
CPU_SET_CALLBACK_X64(manager,
[this](const_node_ptr &node) -> bool {
                         // This is a callback from snippets. If the GroupNorm node is suitable for snippets execution with higher performance,
                         // it will not be decomposed to mvn+reshape+eltwises; instead it is supported directly with snippets.
if (node->is_dynamic() || inferencePrecision != element::f32)
return false;
const auto group_norm = ov::as_type_ptr<const ov::op::v12::GroupNormalization>(node);
if (!group_norm)
return false;
const auto num_groups = static_cast<size_t>(group_norm->get_num_groups());
const auto shape = group_norm->get_input_partial_shape(0).to_shape();
size_t snippets_work_amount = shape[0] * num_groups;
size_t concurrency = parallel_get_max_threads();
if (concurrency > snippets_work_amount)
return false;
size_t spatial_dim = 1;
for (size_t i = 2; i < shape.size(); ++i)
spatial_dim = spatial_dim * shape[i];
size_t snippets_tensor_size = spatial_dim * shape[1] / num_groups * node->get_element_type().size();
size_t cache_size_l1 = dnnl::utils::get_cache_size(1, true);
if (snippets_tensor_size > cache_size_l1) {
return false;
}

return true;
},
ov::pass::GroupNormalizationDecomposition);
}

CPU_ENABLE_PASS_COMMON(manager, ov::pass::SoftmaxDecomposition);
CPU_SET_CALLBACK_COMMON(manager,
Expand Down

0 comments on commit 105387f

Please sign in to comment.