diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 296f50d69b3249..2ebaa7c2ab1728 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -231,6 +231,14 @@ class LinearIR { */ static std::vector propagate_expr_through_shape_infer_ops(const ExpressionPtr& start_expr, bool downstream); + /** + * @brief Get last shape infer op from start_expr in a sequence. If no shape infer op is connected to start_expr, return start_expr. + * @param start_expr Search from start_expr. + * @param downstream search downstream if it's true, otherwise search upstream. + * @return last shape infer expr + */ + static ExpressionPtr get_last_shape_infer_expr(const ExpressionPtr& start_expr, bool downstream); + private: std::shared_ptr m_shape_infer = nullptr; diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index b03648c76dc2c1..2b02bac7b7b5c6 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -140,6 +140,7 @@ class Subgraph : public ov::op::util::SubGraphOp { static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t; static auto is_domain_sensitive_op(const std::shared_ptr& op) -> bool; static auto is_shape_infer_op(const std::shared_ptr& op) -> bool; + static auto get_last_shape_infer_op(const std::shared_ptr& op, bool downstream) -> std::shared_ptr; void data_flow_transformations(const BlockedShapeVector& blocked_input_shapes = {}, const std::vector& input_precisions = {}, diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index 764bde23cad7fc..9669796628ad44 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -154,8 +154,13 @@ VectorDims 
get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port); * @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order. */ VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port); +/** + * @brief Returns element count of a shape + * @param shape input shape + * @return element count of input shape + */ inline auto get_shape_size(const VectorDims& shape) -> size_t { - return std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); + return std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); } /* --------------------------- */ diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 67e7eedda67c1a..3bff272f03f6c4 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -366,13 +366,14 @@ VectorDims LinearIR::get_master_shape() const { } // Note: Snippets would benefit from a more generic master_shape calculation approach. 
// It will be implemented in the scope of ROI propagation activity (ticket 120505) - const auto& source = out_exprs[0]->get_input_port_connector(0)->get_source(); - auto last_exp = source.get_expr(); - if (!m_config.m_enable_domain_optimization && out_exprs.size() == 1 && - ov::is_type(source.get_expr()->get_node())) { - master_shape = utils::get_preordered_vdims(source); - } else if (out_exprs.size() == 1 && ov::is_type(last_exp->get_node())) { - master_shape = utils::get_preordered_vdims(last_exp->get_input_port_connector(0)->get_source()); + if (out_exprs.size() == 1) { + const auto& source = out_exprs[0]->get_input_port_connector(0)->get_source(); + if (!m_config.m_enable_domain_optimization && ov::is_type(source.get_expr()->get_node())) { + master_shape = utils::get_preordered_vdims(source); + } else { + auto last_shape_infer_expr = LinearIR::get_last_shape_infer_expr(out_exprs[0], false); + master_shape = utils::get_preordered_vdims(last_shape_infer_expr->get_input_port_connector(0)->get_source()); + } } else { for (const auto& oe : out_exprs) { const auto& port_desc = oe->get_input_port_descriptor(0); @@ -514,7 +515,7 @@ std::vector LinearIR::propagate_expr_through_shape_infer_ops(cons current_exp = first_child; if (current_exp->get_output_count() == 0) break; - auto consumers = current_exp->get_output_port_connector(0)->get_consumers(); + consumers = current_exp->get_output_port_connector(0)->get_consumers(); first_child = consumers.begin()->get_expr(); } return shape_infer_exprs; @@ -534,6 +535,37 @@ std::vector LinearIR::propagate_expr_through_shape_infer_ops(cons } } +ExpressionPtr LinearIR::get_last_shape_infer_expr(const ExpressionPtr& start_expr, bool downstream) { + auto last_exp = start_expr; + if (downstream) { + if (last_exp->get_output_count() == 0) + return last_exp; + auto consumers = last_exp->get_output_port_connector(0)->get_consumers(); + auto first_child = consumers.begin()->get_expr(); + while 
(op::Subgraph::is_shape_infer_op(first_child->get_node())) { + OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer."); + last_exp = first_child; + if (last_exp->get_output_count() == 0) + break; + consumers = last_exp->get_output_port_connector(0)->get_consumers(); + first_child = consumers.begin()->get_expr(); + } + return last_exp; + } else { + // upstream + if (last_exp->get_input_count() == 0) + return last_exp; + auto first_source = last_exp->get_input_port_connector(0)->get_source().get_expr(); + while (op::Subgraph::is_shape_infer_op(first_source->get_node())) { + last_exp = first_source; + if (last_exp->get_input_count() == 0) + break; + first_source = last_exp->get_input_port_connector(0)->get_source().get_expr(); + } + return last_exp; + } +} + LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs) : ShapeInferSnippetsNode(), m_exprs{std::make_shared(body_exprs)} { diff --git a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp index f287ee9edcfedb..cfdab4b48287c7 100644 --- a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp @@ -46,7 +46,8 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const } } // Propagate to down: in Load. 
Buffer can have several Load - const auto& buffer_out = buffer_expr->get_output_port_connector(0); + auto last_shape_infer = ov::snippets::lowered::LinearIR::get_last_shape_infer_expr(buffer_expr, true); + const auto& buffer_out = last_shape_infer->get_output_port_connector(0); for (const auto& child_expr_input : buffer_out->get_consumers()) { const auto& child_expr = child_expr_input.get_expr(); const auto port = child_expr_input.get_index(); @@ -54,13 +55,12 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const auto memory_access = ov::as_type_ptr(child_node); if (memory_access && memory_access->is_memory_access_input_port(port)) { memory_access->set_input_offset(offset, port); - } else if (ov::is_type(child_node) || op::Subgraph::is_shape_infer_op(child_node)) { + } else if (ov::is_type(child_node)) { // After Loop initialization, Buffer can be connected to LoopEnd - it's ok - // There are also buffer before shape-changing ops continue; } else { OPENVINO_THROW( - "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); + "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); } } } diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index af8d6de30b963b..3174add775fae5 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -190,12 +190,8 @@ void InsertBuffers::insertion(LinearIR& linear_ir, parent_expr_output, m_buffer_allocation_rank); const auto buffer = std::make_shared(parent->output(parent_port), allocation_shape); - if (has_shape_infer_parent) { - linear_ir.insert_node(buffer, std::vector{ parent_expr_output }, buffer_loop_ids, false, pos, - { top_shape_infer_expr->get_input_port(0) }); - } else { - linear_ir.insert_node(buffer, std::vector{ 
parent_expr_output }, buffer_loop_ids, false, pos, { *entry_port }); - } + const auto buffer_consumer = has_shape_infer_parent ? top_shape_infer_expr->get_input_port(0) : *entry_port; + linear_ir.insert_node(buffer, std::vector{ parent_expr_output }, buffer_loop_ids, false, pos, { buffer_consumer }); } } diff --git a/src/common/snippets/src/lowered/pass/insert_load_store.cpp b/src/common/snippets/src/lowered/pass/insert_load_store.cpp index 3e0afe9cf7e3cb..fcc21ceedc1cde 100644 --- a/src/common/snippets/src/lowered/pass/insert_load_store.cpp +++ b/src/common/snippets/src/lowered/pass/insert_load_store.cpp @@ -36,10 +36,7 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const { bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) { std::shared_ptr data_expr = *data_expr_it; - auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, true); - if (!shape_infer_consumers.empty()) - data_expr = shape_infer_consumers.back(); - + data_expr = LinearIR::get_last_shape_infer_expr(data_expr, true); const auto& data_ngraph_output = data_expr->get_node()->output(0); bool was_inserted = false; const auto& data_out = data_expr->get_output_port_connector(0); @@ -60,9 +57,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) { auto data_expr = *data_expr_it; - auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(data_expr, false); - if (!shape_infer_consumers.empty()) - data_expr = shape_infer_consumers.back(); + data_expr = LinearIR::get_last_shape_infer_expr(data_expr, false); const auto& parent_output = data_expr->get_input_port_connector(0)->get_source(); const auto& parent_expr = parent_output.get_expr(); diff --git a/src/common/snippets/src/lowered/pass/validate.cpp b/src/common/snippets/src/lowered/pass/validate.cpp index 
68b4d75c541d57..b9a57801d6a351 100644 --- a/src/common/snippets/src/lowered/pass/validate.cpp +++ b/src/common/snippets/src/lowered/pass/validate.cpp @@ -32,8 +32,7 @@ void validate_ports(const ExpressionPtr& expr) { void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) { OPENVINO_ASSERT(ov::is_type(expr->get_node()), "Parameter validation expects Parameter op"); - auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true); - auto expr_val = shape_infer_consumers.empty() ? expr : shape_infer_consumers.back(); + auto expr_val = LinearIR::get_last_shape_infer_expr(expr, true); auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers(); std::set> layouts; for (const auto& consumer_input : consumer_inputs) { @@ -52,8 +51,7 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) { void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) { OPENVINO_ASSERT(ov::is_type(expr->get_node()), "Result validation expects Result op"); - auto shape_infer_parents = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, false); - auto expr_val = shape_infer_parents.empty() ? expr : shape_infer_parents.back(); + auto expr_val = LinearIR::get_last_shape_infer_expr(expr, false); const auto source = expr_val->get_input_port_connector(0)->get_source(); const auto ma = ov::as_type_ptr(source.get_expr()->get_node()); OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()), @@ -68,10 +66,7 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) { const auto ma = ov::as_type_ptr(source.get_expr()->get_node()); OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()), "Buffer expects MemoryAccess parent"); - - auto shape_infer_consumers = LinearIR::propagate_expr_through_shape_infer_ops(expr, true); - auto expr_val = shape_infer_consumers.empty() ? 
expr : shape_infer_consumers.back(); - + auto expr_val = LinearIR::get_last_shape_infer_expr(expr, true); const auto& out = expr_val->get_output_port_connector(0); const auto consumers = out->get_consumers(); for (const auto& consumer_input : consumers) { diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 6acfe2bb9663da..3e050ee97cdd88 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -87,6 +87,34 @@ auto Subgraph::is_shape_infer_op(const std::shared_ptr& op) -> bool { ov::is_type(op); } +auto Subgraph::get_last_shape_infer_op(const std::shared_ptr& op, bool downstream) -> std::shared_ptr { + auto last_op = op; + if (downstream) { + if (last_op->get_output_size() == 0) + return last_op; + auto first_child = last_op->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); + while (op::Subgraph::is_shape_infer_op(first_child)) { + last_op = first_child; + if (last_op->get_output_size() == 0) + break; + first_child = last_op->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); + } + return last_op; + } else { + // upstream + if (last_op->get_input_size() == 0) + return last_op; + auto first_parent = last_op->get_input_node_shared_ptr(0); + while (op::Subgraph::is_shape_infer_op(first_parent)) { + last_op = first_parent; + if (last_op->get_input_size() == 0) + break; + first_parent = last_op->get_input_node_shared_ptr(0); + } + return last_op; + } +} + void Subgraph::init_config() { auto update = [](bool& flag, bool status) { flag = flag || status; }; const auto ops = body_ptr()->get_ops(); @@ -327,8 +355,7 @@ VectorDims Subgraph::infer_master_shape() { OPENVINO_ASSERT(!output_dims.empty(), "Can't calculate master_shape before the first shape inference"); } else { for (const auto& res : body_ptr()->get_results()) { - auto reshape = ov::as_type_ptr(res->get_input_node_shared_ptr(0)); - auto res_input = reshape ? 
reshape->input(0) : res->input(0); + auto res_input = get_last_shape_infer_op(res, false)->input(0); OPENVINO_ASSERT(res_input.get_partial_shape().is_static(), "Result have dynamic shape in static pipeline"); // We need to account to the shape's layout stored in Output rt_info const auto& planar_shape = utils::get_preordered_pshape(res_input.get_source_output()); diff --git a/src/common/snippets/src/pass/align_element_types.cpp b/src/common/snippets/src/pass/align_element_types.cpp index 34250a7e1a1429..c159167c7496e7 100644 --- a/src/common/snippets/src/pass/align_element_types.cpp +++ b/src/common/snippets/src/pass/align_element_types.cpp @@ -29,7 +29,7 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr& m) for (size_t i = 0; i < m_output_precisions.size(); i++) { const auto needed_out_type = m_output_precisions[i]; if (results[i]->get_input_element_type(0) != needed_out_type) { - std::shared_ptr consumer = results[i]; + std::shared_ptr consumer = op::Subgraph::get_last_shape_infer_op(results[i], false); auto parent_output = consumer->get_input_source_output(0); // Snippets supports Transpose only after Parameter or before Result nodes @@ -76,18 +76,11 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr& m) parameter->set_element_type(needed_in_type); parameter->validate_and_infer_types(); - auto parent_output = parameter->output(0); - auto consumer_inputs = parent_output.get_target_inputs(); - - auto first_child = consumer_inputs.begin()->get_node()->shared_from_this(); // Note: shape infer ops is designed for shape-inference purposes only. 
// It does not process any data (nor does it emit any code), so it doesn't require Convert operations - while (op::Subgraph::is_shape_infer_op(first_child)) { - OPENVINO_ASSERT(consumer_inputs.size() == 1, "Shape infer ops are supposed to be the only consumer"); - parent_output = first_child->output(0); - consumer_inputs = parent_output.get_target_inputs(); - first_child = consumer_inputs.begin()->get_node()->shared_from_this(); - } + auto first_child = op::Subgraph::get_last_shape_infer_op(parameter, true); + auto parent_output = first_child->output(0); + auto consumer_inputs = parent_output.get_target_inputs(); // Snippets supports Transpose only after Parameter or before Result nodes // So we have to insert Convert after Transpose (if there is) on Subgraph inputs diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp index f32985f9999d5b..e82d7cdd5b36dc 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp @@ -201,8 +201,7 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g switch (expr->get_type()) { case snippets::lowered::IOExpression::io_type::INPUT: { // input->shape changing ops->load - auto shape_infer_consumers = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, true); - auto mem_desc_expr = shape_infer_consumers.empty() ? 
expr : shape_infer_consumers.back(); + auto mem_desc_expr = ov::snippets::lowered::LinearIR::get_last_shape_infer_expr(expr, true); auto consumer_inputs = mem_desc_expr->get_output_port_connector(0)->get_consumers(); for (const auto& child_input : consumer_inputs) { const auto ma = ov::as_type_ptr(child_input.get_expr()->get_node()); @@ -217,8 +216,7 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g } case snippets::lowered::IOExpression::io_type::OUTPUT: { // store->shape changing ops->result - auto shape_infer_sources = snippets::lowered::LinearIR::propagate_expr_through_shape_infer_ops(expr, false); - auto mem_desc_expr = shape_infer_sources.empty() ? expr : shape_infer_sources.back(); + auto mem_desc_expr = ov::snippets::lowered::LinearIR::get_last_shape_infer_expr(expr, false); desc = mem_desc_expr->get_input_port_connector(0)->get_source().get_descriptor_ptr(); etype = mem_desc_expr->get_node()->get_input_element_type(0); break; diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 0eff25aed316ce..f740b397b54f9f 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -471,10 +471,9 @@ void Transformations::PreLpt(const std::vector& defaultPrecis }, ov::pass::NormalizeL2Decomposition); - // todo: only support f32 in first version CPU_SET_CALLBACK_X64(manager, - [](const_node_ptr &node) -> bool { - return !node->is_dynamic() && node->get_element_type() == element::f32; + [this](const_node_ptr &node) -> bool { + return !node->is_dynamic() && node->get_element_type() == element::f32 && inferencePrecision != ov::element::bf16; }, ov::pass::GroupNormalizationDecomposition);