diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 9c8ac3f1f25b4d..296f50d69b3249 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -223,6 +223,14 @@ class LinearIR { */ exprIt replace_with_expr(const std::vector& old_exprs, const ExpressionPtr& new_expr); + /** + * @brief Propagate start_expr through zero to several consecutive shape infer exprs(such as reshape, rankNormalization). + * @param start_expr Propagate from start_expr. + * @param downstream Propagate downstream if it's true, otherwise propagate upstream. + * @return shape infer op consumers as a sequence if downstream, or shape infer op sources as a sequence if upstream. + */ + static std::vector propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream); + private: std::shared_ptr m_shape_infer = nullptr; diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index 2c0558abdc7529..b03648c76dc2c1 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -139,6 +139,7 @@ class Subgraph : public ov::op::util::SubGraphOp { // Return estimated unique buffer count (upper bound). It's needed for tokenization static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t; static auto is_domain_sensitive_op(const std::shared_ptr& op) -> bool; + static auto is_shape_infer_op(const std::shared_ptr& op) -> bool; void data_flow_transformations(const BlockedShapeVector& blocked_input_shapes = {}, const std::vector& input_precisions = {}, diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index 41d87f0e0fe83d..764bde23cad7fc 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -154,6 +154,9 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port); * @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order. */ VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port); +inline auto get_shape_size(const VectorDims& shape) -> size_t { + return std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); +} /* --------------------------- */ } // namespace utils diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 05d3a934d2b2a4..67e7eedda67c1a 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -12,6 +12,7 @@ #include "openvino/core/graph_util.hpp" #include "openvino/core/type.hpp" #include "snippets/utils.hpp" +#include "snippets/op/subgraph.hpp" namespace ov { namespace snippets { @@ -496,6 +497,43 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector& o return replace_with_expr(old_exprs, new_expr, insertion_place); } +std::vector LinearIR::propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream) { + std::vector shape_infer_exprs; + auto current_exp = start_expr; + if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) { + shape_infer_exprs.push_back(current_exp); + } + if (downstream) { + if (current_exp->get_output_count() == 0) + return shape_infer_exprs; + auto consumers = current_exp->get_output_port_connector(0)->get_consumers(); + auto first_child = consumers.begin()->get_expr(); + while (op::Subgraph::is_shape_infer_op(first_child->get_node())) { + OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer."); + shape_infer_exprs.push_back(first_child); + current_exp = first_child; + if (current_exp->get_output_count() == 0) + break; + auto consumers = current_exp->get_output_port_connector(0)->get_consumers(); + first_child = consumers.begin()->get_expr(); + } + return shape_infer_exprs; + } else { + // upstream + if (current_exp->get_input_count() == 0) + return shape_infer_exprs; + auto first_source = current_exp->get_input_port_connector(0)->get_source().get_expr(); + while (op::Subgraph::is_shape_infer_op(first_source->get_node())) { + shape_infer_exprs.push_back(first_source); + current_exp = first_source; + if (current_exp->get_input_count() == 0) + break; + first_source = current_exp->get_input_port_connector(0)->get_source().get_expr(); + } + return shape_infer_exprs; + } +} + LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs) : ShapeInferSnippetsNode(), m_exprs{std::make_shared(body_exprs)} { diff --git a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp index aa13a7681dcea3..f287ee9edcfedb 100644 --- a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp @@ -54,12 +54,13 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const auto memory_access = ov::as_type_ptr(child_node); if (memory_access && memory_access->is_memory_access_input_port(port)) { memory_access->set_input_offset(offset, port); - } else if (ov::is_type(child_node)) { + } else if (ov::is_type(child_node) || op::Subgraph::is_shape_infer_op(child_node)) { // After Loop initialization, Buffer can be connected to LoopEnd - it's ok + // There are also buffer before shape-changing ops continue; } else { - // OPENVINO_THROW( - // "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); + OPENVINO_THROW( + "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); } } } diff --git a/src/common/snippets/src/lowered/pass/assign_registers.cpp b/src/common/snippets/src/lowered/pass/assign_registers.cpp index 13b4727151681b..65499819b3685a 100644 --- a/src/common/snippets/src/lowered/pass/assign_registers.cpp +++ b/src/common/snippets/src/lowered/pass/assign_registers.cpp @@ -5,6 +5,7 @@ #include "snippets/lowered/pass/assign_registers.hpp" #include "snippets/lowered/linear_ir.hpp" +#include "snippets/op/subgraph.hpp" #include "snippets/snippets_isa.hpp" #include "snippets/itt.hpp" @@ -79,23 +80,22 @@ bool AssignRegisters::run(LinearIR& linear_ir) { if (io_expr->get_type() == IOExpression::io_type::INPUT) { const auto& out_connector = expr->get_output_port_connector(0); manually_assigned_gprs[out_connector] = io_expr->get_index(); - // TODO [96434]: Support RankNormalization/Reshape in arbitrary place in pipeline, not just after inputs - // reshape rankNormalization sequence - auto consumer_inputs = out_connector->get_consumers(); - auto child_exp = consumer_inputs.begin()->get_expr(); - while (ov::is_type(child_exp->get_node()) || - ov::is_type(child_exp->get_node())) { - OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization or Reshape is supposed to be the only consumer"); - manually_assigned_gprs[child_exp->get_output_port_connector(0)] = io_expr->get_index(); - consumer_inputs = child_exp->get_output_port_connector(0)->get_consumers(); - child_exp = consumer_inputs.begin()->get_expr(); + // TODO [96434]: Support shape infer ops in arbitrary place in pipeline, not just after inputs + // shape infer ops sequence after input + auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(io_expr, true); + if (!shape_infer_consumers.empty()) { + for (const auto& child_shape_infer_expr : shape_infer_consumers) { + manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index(); + } } } else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) { manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index(); - // reshape before result - const auto &parent = expr->get_input_port_connector(0)->get_source().get_expr(); - if (ov::is_type(parent->get_node())) { - manually_assigned_gprs[parent->get_input_port_connector(0)] = num_parameters + io_expr->get_index(); + // shape infer ops sequence before result + auto shape_infer_sources = LinearIR::propagate_expr_through_shape_infer_ops(io_expr, false); + if (!shape_infer_sources.empty()) { + for (const auto& parent_shape_infer_expr : shape_infer_sources) { + manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index(); + } } } else { OPENVINO_THROW("Unsupported io_type detected"); @@ -106,13 +106,15 @@ bool AssignRegisters::run(LinearIR& linear_ir) { if (ov::is_type(buffer)) { manually_assigned_gprs[expr->get_input_port_connector(0)] = static_cast(num_results + num_parameters + buffer_id); - // reshape in the middle of subgraph. IntermediateMemoryBuffer is inserted before reshape as new loop should start. - const auto& first_consumer = expr->get_output_port_connector(0)->get_consumers().begin()->get_expr(); - if (ov::is_type(first_consumer->get_node())) { - manually_assigned_gprs[first_consumer->get_input_port_connector(0)] = - static_cast(num_results + num_parameters + buffer_id); - manually_assigned_gprs[first_consumer->get_output_port_connector(0)] = - static_cast(num_results + num_parameters + buffer_id); + // shape infer ops in the middle of subgraph. IntermediateMemoryBuffer is inserted before reshape as new loop should start. + // child shape info ops share the same memory as IntermediateMemoryBuffer. + auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true); + if (!shape_infer_consumers.empty()) { + for (const auto& child_shape_infer_expr : shape_infer_consumers) { + manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] = + manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = + static_cast(num_results + num_parameters + buffer_id); + } } } manually_assigned_gprs[expr->get_output_port_connector(0)] = diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 2b10c1934a33b1..af8d6de30b963b 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -149,15 +149,17 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const auto node = expr->get_node(); auto parent_expr_output = expr->get_input_port_connector(port_idx)->get_source(); - auto first_not_reshape_parent_output = [&]() { - auto parent_expr = parent_expr_output.get_expr(); - while (is_type(parent_expr->get_node())) { - parent_expr_output = parent_expr->get_input_port_connector(0)->get_source(); - parent_expr = parent_expr_output.get_expr(); - } - }; - // this parent(before reshape) is used to determine if buffer needed according loopInfo - first_not_reshape_parent_output(); + const auto& first_parent_expr = parent_expr_output.get_expr(); + bool has_shape_infer_parent = false; + auto top_shape_infer_expr = expr; + // parent before shape infer ops is used to determine if buffer needed according loopInfo + auto shape_infer_parents = LinearIR::propagate_expr_through_shape_infer_ops(first_parent_expr, false); + if (!shape_infer_parents.empty()) { + parent_expr_output = shape_infer_parents.back()->get_input_port_connector(0)->get_source(); + has_shape_infer_parent = true; + top_shape_infer_expr = shape_infer_parents.back(); + } + const auto& parent_expr = parent_expr_output.get_expr(); const auto& parent_port = parent_expr_output.get_index(); const auto& parent = parent_expr->get_node(); @@ -167,15 +169,6 @@ void InsertBuffers::insertion(LinearIR& linear_ir, ov::is_type(parent)) continue; - // insert buffer before reshape - auto buffer_child = expr; - bool parent_is_reshape = false; - auto p_exp = expr->get_input_port_connector(port_idx)->get_source().get_expr(); - if (is_type(p_exp->get_node())) { - buffer_child = p_exp; - parent_is_reshape = true; - } - // Each MemoryAccess op needs Buffer const auto parent_ma = ov::as_type_ptr(parent); const auto node_ma = ov::as_type_ptr(node); @@ -197,9 +190,9 @@ void InsertBuffers::insertion(LinearIR& linear_ir, parent_expr_output, m_buffer_allocation_rank); const auto buffer = std::make_shared(parent->output(parent_port), allocation_shape); - if (parent_is_reshape) { + if (has_shape_infer_parent) { linear_ir.insert_node(buffer, std::vector{ parent_expr_output }, buffer_loop_ids, false, pos, - { buffer_child->get_input_port(0) }); + { top_shape_infer_expr->get_input_port(0) }); } else { linear_ir.insert_node(buffer, std::vector{ parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port }); } diff --git a/src/common/snippets/src/lowered/pass/insert_load_store.cpp b/src/common/snippets/src/lowered/pass/insert_load_store.cpp index defaeb6e4ce0df..3e0afe9cf7e3cb 100644 --- a/src/common/snippets/src/lowered/pass/insert_load_store.cpp +++ b/src/common/snippets/src/lowered/pass/insert_load_store.cpp @@ -36,23 +36,9 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const { bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) { std::shared_ptr data_expr = *data_expr_it; - const auto& consumer_inputs = data_expr->get_output_port_connector(0)->get_consumers(); - auto first_reshape_consumer = [&]() { - auto current_exp = data_expr; - auto first_consumer = consumer_inputs.begin()->get_expr(); - while (1) { - if (is_type(first_consumer->get_node()) || - is_type(first_consumer->get_node())) { - current_exp = first_consumer; - first_consumer = first_consumer->get_output_port_connector(0)->get_consumers().begin()->get_expr(); - // OPENVINO_ASSERT(current_exp->get_output_port_connector(0)->get_consumers().size() == 1, - // "RankNormalization or Reshape is supposed to be the only consumer"); - } else { - return current_exp; - } - } - }; - data_expr = first_reshape_consumer(); + auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, true); + if (!shape_infer_consumers.empty()) + data_expr = shape_infer_consumers.back(); const auto& data_ngraph_output = data_expr->get_node()->output(0); bool was_inserted = false; @@ -74,16 +60,15 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) { auto data_expr = *data_expr_it; - auto parent_output = data_expr->get_input_port_connector(0)->get_source(); - auto parent_expr = parent_output.get_expr(); - if (is_type(parent_expr->get_node())) { - data_expr = parent_expr; - parent_output = data_expr->get_input_port_connector(0)->get_source(); - parent_expr = parent_output.get_expr(); - } - auto port = parent_output.get_index(); - auto parent = parent_expr->get_node(); - auto ma = ov::as_type_ptr(parent); + auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, false); + if (!shape_infer_consumers.empty()) + data_expr = shape_infer_consumers.back(); + + const auto& parent_output = data_expr->get_input_port_connector(0)->get_source(); + const auto& parent_expr = parent_output.get_expr(); + const auto port = parent_output.get_index(); + const auto& parent = parent_expr->get_node(); + const auto ma = ov::as_type_ptr(parent); if (ma && ma->is_memory_access_output_port(port)) return false; diff --git a/src/common/snippets/src/lowered/pass/validate.cpp b/src/common/snippets/src/lowered/pass/validate.cpp index 8dc95a94f9c015..68b4d75c541d57 100644 --- a/src/common/snippets/src/lowered/pass/validate.cpp +++ b/src/common/snippets/src/lowered/pass/validate.cpp @@ -32,13 +32,9 @@ void validate_ports(const ExpressionPtr& expr) { void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) { OPENVINO_ASSERT(ov::is_type(expr->get_node()), "Parameter validation expects Parameter op"); - auto consumer_inputs = expr->get_output_port_connector(0)->get_consumers(); - const auto& first_consumer = consumer_inputs.begin()->get_expr(); - if (is_type(first_consumer->get_node())) { - OPENVINO_ASSERT(consumer_inputs.size() == 1, - "If there is RankNormalization after Parameter, it should be single consumer of the Parameter"); - consumer_inputs = first_consumer->get_output_port_connector(0)->get_consumers(); - } + auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true); + auto expr_val = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back(); + auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers(); std::set> layouts; for (const auto& consumer_input : consumer_inputs) { const auto& node = consumer_input.get_expr()->get_node(); @@ -56,7 +52,9 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) { void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) { OPENVINO_ASSERT(ov::is_type(expr->get_node()), "Result validation expects Result op"); - const auto source = expr->get_input_port_connector(0)->get_source(); + auto shape_infer_parents = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, false); + auto expr_val = shape_infer_parents.empty() ? expr : shape_infer_parents.back(); + const auto source = expr_val->get_input_port_connector(0)->get_source(); const auto ma = ov::as_type_ptr(source.get_expr()->get_node()); OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()), "Result expects MemoryAccess parent"); @@ -71,7 +69,10 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) { OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()), "Buffer expects MemoryAccess parent"); - const auto& out = expr->get_output_port_connector(0); + auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true); + auto expr_val = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back(); + + const auto& out = expr_val->get_output_port_connector(0); const auto consumers = out->get_consumers(); for (const auto& consumer_input : consumers) { const auto& node = consumer_input.get_expr()->get_node(); diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index df4fc1693590f7..6acfe2bb9663da 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -18,7 +18,6 @@ #include "snippets/pass/canonicalization.hpp" #include "snippets/pass/align_element_types.hpp" #include "snippets/pass/reduce_to_snippets_reduce.hpp" -#include "snippets/pass/gn_decomposition.hpp" #include "snippets/utils.hpp" @@ -57,8 +56,6 @@ #include #include -#include "snippets/lowered/pass/serialize_control_flow.hpp" - using namespace std; using namespace ov::op::util; @@ -81,7 +78,13 @@ auto Subgraph::is_domain_sensitive_op(const std::shared_ptr& op) -> bo ov::is_type(op) || ov::is_type(op) || // Broadcast is domain sensetive op because the output shape depends on ov::is_type(op) || // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern - ov::is_type(op); + ov::is_type(op) || + ov::is_type(op); +} + +auto Subgraph::is_shape_infer_op(const std::shared_ptr& op) -> bool { + return ov::is_type(op) || + ov::is_type(op); } void Subgraph::init_config() { @@ -277,7 +280,8 @@ auto Subgraph::constant_input_should_be_inside_body(const std::shared_ptr(node) || ov::is_type(node) || ov::is_type(node) || - ov::is_type(node); + ov::is_type(node) || + ov::is_type(node); } bool Subgraph::check_broadcast(const std::shared_ptr& node) noexcept { @@ -388,10 +392,6 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::data_flow_transformations") ov::snippets::pass::Manager manager; - // GNDecomposition should be before canonicalization(rankNorm) as scale/bias shape is C and need special process. - if (config.m_has_domain_sensitive_ops) - manager.register_pass(); - if (!blocked_input_shapes.empty()) manager.register_pass(blocked_input_shapes); if (!input_precisions.empty() && !output_precisions.empty()) @@ -414,12 +414,6 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input manager.register_positioned_passes(backend_passes); manager.run_passes(body_ptr()); - - // ov::pass::Manager magr; - // std::string xmlo = "data_flow.xml"; - // std::string bino = "data_flow.bin"; - // magr.register_pass(xmlo, bino); - // magr.run_passes(body_ptr()); } void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir, @@ -493,9 +487,6 @@ snippets::Schedule Subgraph::generate_from_linear_ir(const std::shared_ptrgenerate(linear_ir, lowering_result, compile_params); VectorDims parallel_exec_domain = linear_ir.get_master_shape(); diff --git a/src/common/snippets/src/pass/align_element_types.cpp b/src/common/snippets/src/pass/align_element_types.cpp index 08430af05a0745..34250a7e1a1429 100644 --- a/src/common/snippets/src/pass/align_element_types.cpp +++ b/src/common/snippets/src/pass/align_element_types.cpp @@ -79,14 +79,14 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr& m) auto parent_output = parameter->output(0); auto consumer_inputs = parent_output.get_target_inputs(); - const auto& first_child = consumer_inputs.begin()->get_node()->shared_from_this(); - // Note: RankNormalization of is designed for shape-inference purposes only. + auto first_child = consumer_inputs.begin()->get_node()->shared_from_this(); + // Note: shape infer ops is designed for shape-inference purposes only. // It does not process any data (nor does it emit any code), so it doesn't require Convert operations - if (is_type(first_child) || - is_type(first_child)) { - OPENVINO_ASSERT(consumer_inputs.size() == 1, "RankNormalization is supposed to be the only consumer"); + while (op::Subgraph::is_shape_infer_op(first_child)) { + OPENVINO_ASSERT(consumer_inputs.size() == 1, "Shape infer ops are supposed to be the only consumer"); parent_output = first_child->output(0); consumer_inputs = parent_output.get_target_inputs(); + first_child = consumer_inputs.begin()->get_node()->shared_from_this(); } // Snippets supports Transpose only after Parameter or before Result nodes diff --git a/src/common/snippets/src/pass/common_optimizations.cpp b/src/common/snippets/src/pass/common_optimizations.cpp index 1e10d2dc6dfe6e..7c8089bc776bec 100644 --- a/src/common/snippets/src/pass/common_optimizations.cpp +++ b/src/common/snippets/src/pass/common_optimizations.cpp @@ -6,6 +6,7 @@ #include "snippets/pass/fq_decomposition.hpp" #include "snippets/pass/softmax_reshape_elimination.hpp" +#include "snippets/pass/gn_decomposition.hpp" #include "snippets/pass/explicit_transpose_matmul_inputs.hpp" #include "snippets/pass/transpose_decomposition.hpp" #include "snippets/pass/fuse_transpose_brgemm.hpp" @@ -50,6 +51,7 @@ CommonOptimizations::CommonOptimizations(const SnippetsTokenization::Config& con REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::ExplicitTransposeMatMulInputs, is_domain_sensitive); REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::CommonFakeQuantizeDecomposition, is_quantized); REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::SoftmaxReshapeElimination, is_domain_sensitive); + REGISTER_SNIPPETS_PASS(manager, ov::snippets::pass::GNDecomposition, is_domain_sensitive); manager.run_passes(body); ov::snippets::pass::CommonOptimizations::SubgraphManager subgraph_manager; diff --git a/src/common/snippets/src/pass/gn_decomposition.cpp b/src/common/snippets/src/pass/gn_decomposition.cpp index ea5da94483130f..aec78587b588a5 100644 --- a/src/common/snippets/src/pass/gn_decomposition.cpp +++ b/src/common/snippets/src/pass/gn_decomposition.cpp @@ -5,6 +5,7 @@ #include "snippets/pass/gn_decomposition.hpp" #include "openvino/op/group_normalization.hpp" +#include "snippets/op/reduce.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "snippets/itt.hpp" #include "snippets/lowered/port_descriptor.hpp" @@ -57,10 +58,8 @@ GNDecomposition::GNDecomposition() { reshaped_node1 = std::make_shared(reshaped_node_orig, element::f32); } - // reduceSum on dimension [C / group * spatial] - std::vector axis(1, 3); - auto axis_node = std::make_shared(element::i64, Shape{axis.size()}, axis); - const auto reduce_sum = std::make_shared(reshaped_node1, axis_node, true); + const auto reduce_sum = std::make_shared(reshaped_node1, group_rank - 1); + op::ReduceBase::compute_and_set_reduce_subtensors(reduce_sum); // reduceMean auto group_shape_static = group_shape.to_shape(); @@ -78,7 +77,8 @@ GNDecomposition::GNDecomposition() { auto sqr_const = std::make_shared(element::f32, Shape{1}, std::vector{2}); auto sqr = std::make_shared(sub_mean, sqr_const); // reduceSum((x - mean) ^ 2) - auto sqr_reduce_sum = std::make_shared(sqr, axis_node, true); + auto sqr_reduce_sum = std::make_shared(sqr, group_rank - 1); + op::ReduceBase::compute_and_set_reduce_subtensors(sqr_reduce_sum); // reduceMean((x - mean) ^ 2) const auto group_size_inv_node_aux = std::make_shared(element::f32, Shape{}, std::vector{group_size_inv}); auto sqr_mean = std::make_shared(sqr_reduce_sum, group_size_inv_node_aux); @@ -90,6 +90,23 @@ GNDecomposition::GNDecomposition() { // divide variance const auto variance_inv = std::make_shared(variance, -1.f); + + // remove invariance in inner loop + std::vector subtensor_invariance(group_rank, 1); + subtensor_invariance[3] = PortDescriptor::ServiceDimensions::FULL_DIM; + PortDescriptorUtils::set_port_descriptor_ptr(reduce_mean->input(0), std::make_shared(reduce_mean->input(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(reduce_mean->output(0), std::make_shared(reduce_mean->output(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(sqr_mean->input(0), std::make_shared(sqr_mean->input(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(sqr_mean->input(1), std::make_shared(sqr_mean->input(1), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(sqr_mean->output(0), std::make_shared(sqr_mean->output(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(eps_add->input(0), std::make_shared(eps_add->input(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(eps_add->input(1), std::make_shared(eps_add->input(1), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(eps_add->output(0), std::make_shared(eps_add->output(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(variance->input(0), std::make_shared(variance->input(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(variance->output(0), std::make_shared(variance->output(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(variance_inv->input(0), std::make_shared(variance_inv->input(0), subtensor_invariance)); + PortDescriptorUtils::set_port_descriptor_ptr(variance_inv->output(0), std::make_shared(variance_inv->output(0), subtensor_invariance)); + auto mvn = std::make_shared(sub_mean, variance_inv); // reshape mvn from [N, group, 1, (C / group) * spatial] to [N, group, C / group, spatial] @@ -122,19 +139,12 @@ GNDecomposition::GNDecomposition() { auto result_prec = group_norm_node->get_output_element_type(0); std::shared_ptr biased_node_convert = biased_node; if (result_prec != element::f32) { - biased_node_convert = std::make_shared(biased_node, data.get_element_type()); + biased_node_convert = std::make_shared(biased_node, result_prec); } // reshape_back [N, group, C / group, spatial] to [N, C, spatial] const auto reshape_back_node = std::make_shared(biased_node_convert, orig_shape); - std::vector subtensor(group_rank, 1); - subtensor[3] = PortDescriptor::ServiceDimensions::FULL_DIM; - PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->input(0), std::make_shared(reduce_sum->input(0), subtensor)); - PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->output(0), std::make_shared(reduce_sum->output(0), subtensor)); - PortDescriptorUtils::set_port_descriptor_ptr(sqr_reduce_sum->input(0), std::make_shared(sqr_reduce_sum->input(0), subtensor)); - PortDescriptorUtils::set_port_descriptor_ptr(sqr_reduce_sum->output(0), std::make_shared(sqr_reduce_sum->output(0), subtensor)); - return ov::replace_node_update_name(group_norm_node, reshape_back_node); }; diff --git a/src/common/snippets/src/pass/gn_tokenization.cpp b/src/common/snippets/src/pass/gn_tokenization.cpp index 4332d4d44d66e0..62fe124b2a4f01 100644 --- a/src/common/snippets/src/pass/gn_tokenization.cpp +++ b/src/common/snippets/src/pass/gn_tokenization.cpp @@ -20,8 +20,7 @@ ov::snippets::pass::TokenizeGNSnippets::TokenizeGNSnippets() { ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::pass::TokenizeGNSnippets") auto group_norm_node = ov::as_type_ptr(m.get_match_root()); - if (group_norm_node->is_dynamic() || - TokenizeSnippets::get_supported_element_types().count(group_norm_node->get_element_type()) == 0) + if (group_norm_node->is_dynamic() || group_norm_node->get_element_type() != element::f32) return false; auto subgraph = op::Subgraph::wrap_node_as_subgraph(group_norm_node); diff --git a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp index ba7ebe082a6fe4..371de613305f37 100644 --- a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp +++ b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp @@ -255,15 +255,12 @@ Result ReshapeShapeInfer::infer(const std::vector& input_shapes) OPENVINO_ASSERT(input_shapes.size() == 1, "Invalid number of shapes is passed in ReshapeShapeInfer"); OPENVINO_ASSERT(target_shape.is_static(), "target_shape should be static in ReshapeShapeInfer"); VectorDims result_shape = target_shape.get_shape(); - const auto input_elems = - std::accumulate(input_shapes[0].get().begin(), input_shapes[0].get().end(), static_cast(1), std::multiplies()); - const auto output_elems = - std::accumulate(result_shape.begin(), result_shape.end(), static_cast(1), std::multiplies()); + const auto input_elems = utils::get_shape_size(input_shapes[0].get()); + const auto output_elems = utils::get_shape_size(result_shape); OPENVINO_ASSERT(input_elems == output_elems, "Tensor volume should be the same after reshape in ReshapeShapeInfer"); return {{result_shape}, ShapeInferStatus::success}; } - } // namespace snippets } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp index 9aec5d4a933f5e..f32985f9999d5b 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp @@ -200,13 +200,10 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g element::Type etype; switch (expr->get_type()) { case snippets::lowered::IOExpression::io_type::INPUT: { - // Note that here we consider only the first child (which is usually load), - // but often there is another child - LoopEnd - auto consumer_inputs = expr->get_output_port_connector(0)->get_consumers(); - const auto& first_consumer = consumer_inputs.begin()->get_expr(); - // If there is a RankNormalization op after a parameter - we should skip it - if (is_type(first_consumer->get_node())) - consumer_inputs = first_consumer->get_output_port_connector(0)->get_consumers(); + // input->shape changing ops->load + auto shape_infer_consumers = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, true); + auto mem_desc_expr = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back(); + auto consumer_inputs = mem_desc_expr->get_output_port_connector(0)->get_consumers(); for (const auto& child_input : consumer_inputs) { const auto ma = ov::as_type_ptr(child_input.get_expr()->get_node()); if (ma && ma->is_memory_access_input_port(child_input.get_index())) { @@ -214,19 +211,16 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g break; } } - etype = expr->get_node()->get_output_element_type(0); + etype = mem_desc_expr->get_node()->get_output_element_type(0); + break; break; } case snippets::lowered::IOExpression::io_type::OUTPUT: { - // store->reshape->result - const auto& source = expr->get_input_port_connector(0)->get_source(); - auto p_exp = source.get_expr(); - if (ov::is_type(p_exp->get_node())) { - desc = p_exp->get_input_port_connector(0)->get_source().get_descriptor_ptr(); - } else { - desc = expr->get_input_port_connector(0)->get_source().get_descriptor_ptr(); - } - etype = expr->get_node()->get_input_element_type(0); + // store->shape changing ops->result + auto shape_infer_sources = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, false); + auto mem_desc_expr = shape_infer_sources.empty() ? expr : shape_infer_sources.back(); + desc = mem_desc_expr->get_input_port_connector(0)->get_source().get_descriptor_ptr(); + etype = mem_desc_expr->get_node()->get_input_element_type(0); break; } default : { OPENVINO_THROW("Kernel detected unsupported io_type"); diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 623d7ae247f4a7..f7a50ffa14852f 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -641,12 +641,6 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp) SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert, ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape); - // ov::pass::Manager magr; - // std::string xmlo = "original.xml"; - // std::string bino = "original.bin"; - // magr.register_pass(xmlo, bino); - // magr.run_passes(snippetAttrs.snippet->body_ptr()); - schedule = snippetAttrs.snippet->generate_from_linear_ir(std::make_shared(), backend_passes, reinterpret_cast(jcp)); diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index de25836e4b0417..0eff25aed316ce 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -474,8 +474,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis // todo: only support f32 in first version CPU_SET_CALLBACK_X64(manager, [](const_node_ptr &node) -> bool { - return !node->is_dynamic() && - ov::snippets::pass::TokenizeSnippets::get_supported_element_types().count(node->get_element_type()) != 0; + return !node->is_dynamic() && node->get_element_type() == element::f32; }, ov::pass::GroupNormalizationDecomposition); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/single_layer_tests/group_normalization.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/single_layer_tests/group_normalization.cpp index df2416102e450b..bd7257318235be 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/single_layer_tests/group_normalization.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/single_layer_tests/group_normalization.cpp @@ -8,8 +8,6 @@ using ov::test::GroupNormalizationTest; const std::vector netPrecisions = { ov::element::f32, - // ov::element::bf16, // remove specific merge convert - // ov::element::i8 // ref impl does not support int8 precision }; // static shapes @@ -42,6 +40,11 @@ const std::vector epsilon = { 0.0001 }; +std::vector additionalConfig = { + {{ov::hint::inference_precision(ov::element::f32)}}, + {{ov::hint::inference_precision(ov::element::bf16)}} +}; + INSTANTIATE_TEST_SUITE_P( smoke_GroupNormalizationStatic, GroupNormalizationTest, @@ -52,7 +55,7 @@ INSTANTIATE_TEST_SUITE_P( testing::ValuesIn(numGroups), testing::ValuesIn(epsilon), testing::Values(ov::test::utils::DEVICE_CPU), - testing::Values(ov::AnyMap())), + testing::ValuesIn(additionalConfig)), GroupNormalizationTest::getTestCaseName); INSTANTIATE_TEST_SUITE_P( @@ -65,7 +68,7 @@ INSTANTIATE_TEST_SUITE_P( testing::ValuesIn(numGroups), testing::ValuesIn(epsilon), testing::Values(ov::test::utils::DEVICE_CPU), - testing::Values(ov::AnyMap())), + testing::ValuesIn(additionalConfig)), GroupNormalizationTest::getTestCaseName); } // anonymous namespace \ No newline at end of file diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/single_op/group_normalization.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/single_op/group_normalization.hpp index 612c53db90ab39..606ee8ede9e972 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/single_op/group_normalization.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/single_op/group_normalization.hpp @@ -27,8 +27,8 @@ class GroupNormalizationTest : public testing::WithParamInterfaceGetParam(); + std::tie(ngPrc, inType, outType, shapes, num_groups, epsilon, targetDevice, additional_config) = this->GetParam(); InputShape biasInputShape = ExtractBiasShape(shapes); init_input_shapes({shapes, biasInputShape, biasInputShape}); ov::ParameterVector params; @@ -73,6 +78,8 @@ class GroupNormalizationTest : public testing::WithParamInterface(results, params, "GroupNormalization"); }