diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp
index 8ee91cbd56c31a..9c8ac3f1f25b4d 100644
--- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp
+++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp
@@ -223,34 +223,6 @@ class LinearIR {
      */
     exprIt replace_with_expr(const std::vector& old_exprs, const ExpressionPtr& new_expr);
 
-    /**
-     * @brief Get zero to several consecutive child shape infer exprs(such as reshape, rankNormalization) from start_expr.
-     * @param start_expr Collect from start_expr.
-     * @return shape infer expression consumers as a sequence.
-     */
-    static std::vector get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr);
-
-    /**
-     * @brief Get zero to several consecutive parent shape infer exprs(such as reshape, rankNormalization) from start_expr.
-     * @param start_expr Collect from start_expr.
-     * @return shape infer expression sources as a sequence.
-     */
-    static std::vector get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr);
-
-    /**
-     * @brief Get last child shape infer op from start_expr in a sequence. If no shape infer op is connect to start_expr, return start_expr.
-     * @param start_expr Search from start_expr.
-     * @return last child shape infer expr
-     */
-    static ExpressionPtr get_last_child_shape_infer_expr(const ExpressionPtr& start_expr);
-
-    /**
-     * @brief Get last parent shape infer op from start_expr in a sequence. If no shape infer op is connect to start_expr, return start_expr.
-     * @param start_expr Search from start_expr.
-     * @return last parent shape infer expr
-     */
-    static ExpressionPtr get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr);
-
 private:
     std::shared_ptr m_shape_infer = nullptr;
diff --git a/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
index a3dffd973c93dd..c06daa8eb1143f 100644
--- a/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
+++ b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
@@ -76,7 +76,8 @@ class ReduceShapeInfer : public IShapeInferSnippets {
 };
 
 class ReshapeShapeInfer : public IShapeInferSnippets {
-    ov::PartialShape target_shape;
+    VectorDims target_shape;
+    size_t target_shape_volume = 0;
 public:
     explicit ReshapeShapeInfer(const std::shared_ptr& n);
     Result infer(const std::vector& input_shapes) override;
diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp
index 9669796628ad44..acb5d2f20bcc2c 100644
--- a/src/common/snippets/include/snippets/utils.hpp
+++ b/src/common/snippets/include/snippets/utils.hpp
@@ -154,6 +154,8 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
 * @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order.
 */
VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port);
+/* --------------------------- */
+
 /**
 * @brief Returns element count of a shape
 * @param shape input shape
@@ -162,7 +164,22 @@ VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_po
 inline auto get_shape_size(const VectorDims& shape) -> size_t {
     return std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies());
 }
-/* --------------------------- */
+
+/**
+ * @brief Get zero or more consecutive child shape infer exprs (such as Reshape, RankNormalization) starting from start_expr.
+ *        Since a node may have multiple outputs, this function returns the first (leftmost) legal sequence.
+ * @param start_expr Expression to collect from.
+ * @return shape infer expression consumers as a sequence.
+ */
+std::vector get_first_child_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr);
+
+/**
+ * @brief Get zero or more consecutive parent shape infer exprs (such as Reshape, RankNormalization) starting from start_expr.
+ *        Since a node may have multiple inputs, this function returns the first (leftmost) legal sequence.
+ * @param start_expr Expression to collect from.
+ * @return shape infer expression sources as a sequence.
+ */
+std::vector get_first_parent_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr);
 } // namespace utils
 } // namespace snippets
diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp
index 38187346d878f3..a21e0c4f3b6088 100644
--- a/src/common/snippets/src/lowered/linear_ir.cpp
+++ b/src/common/snippets/src/lowered/linear_ir.cpp
@@ -371,8 +371,9 @@ VectorDims LinearIR::get_master_shape() const {
         if (!m_config.m_enable_domain_optimization && ov::is_type(source.get_expr()->get_node())) {
             master_shape = utils::get_preordered_vdims(source);
         } else {
-            auto last_shape_infer_expr = LinearIR::get_last_parent_shape_infer_expr(out_exprs[0]);
-            master_shape = utils::get_preordered_vdims(last_shape_infer_expr->get_input_port_connector(0)->get_source());
+            const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(out_exprs[0]);
+            const auto& expr = shape_infer_seq.empty() ? out_exprs[0] : shape_infer_seq.back();
+            master_shape = utils::get_preordered_vdims(expr->get_input_port_connector(0)->get_source());
         }
     } else {
         for (const auto& oe : out_exprs) {
@@ -498,92 +499,6 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector& o
     return replace_with_expr(old_exprs, new_expr, insertion_place);
 }
 
-std::vector LinearIR::get_child_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
-    std::vector shape_infer_exprs;
-    auto current_exp = start_expr;
-    if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
-        OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
-        shape_infer_exprs.push_back(current_exp);
-    }
-    if (current_exp->get_output_count() == 0)
-        return shape_infer_exprs;
-    auto output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
-    auto first_child = output_consumers.begin()->get_expr();
-    while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
-        OPENVINO_ASSERT(output_consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
-        shape_infer_exprs.push_back(first_child);
-        current_exp = first_child;
-        if (current_exp->get_output_count() == 0)
-            break;
-        output_consumers = current_exp->get_output_port_connector(0)->get_consumers();
-        first_child = output_consumers.begin()->get_expr();
-    }
-    return shape_infer_exprs;
-}
-
-std::vector LinearIR::get_parent_shape_infer_expr_seq(const ExpressionPtr& start_expr) {
-    std::vector shape_infer_exprs;
-    auto current_exp = start_expr;
-    if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
-        OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
-        shape_infer_exprs.push_back(current_exp);
-    }
-    if (current_exp->get_input_count() == 0)
-        return shape_infer_exprs;
-    auto input = current_exp->get_input_port_connector(0);
-    auto first_parent = input->get_source().get_expr();
-    while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
-        shape_infer_exprs.push_back(first_parent);
-        current_exp = first_parent;
-        if (current_exp->get_input_count() == 0)
-            break;
-        input = current_exp->get_input_port_connector(0);
-        first_parent = input->get_source().get_expr();
-        if (!ov::is_type(first_parent->get_node())) {
-            // there are maybe some loopEnd consumers of store as well for loop code gen purpose
-            OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops.");
-        }
-    }
-    return shape_infer_exprs;
-}
-
-ExpressionPtr LinearIR::get_last_child_shape_infer_expr(const ExpressionPtr& start_expr) {
-    auto last_exp = start_expr;
-    if (last_exp->get_output_count() == 0)
-        return last_exp;
-    auto consumers = last_exp->get_output_port_connector(0)->get_consumers();
-    auto first_child = consumers.begin()->get_expr();
-    while (op::Subgraph::is_shape_infer_op(first_child->get_node())) {
-        OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
-        last_exp = first_child;
-        if (last_exp->get_output_count() == 0)
-            break;
-        consumers = last_exp->get_output_port_connector(0)->get_consumers();
-        first_child = consumers.begin()->get_expr();
-    }
-    return last_exp;
-}
-
-ExpressionPtr LinearIR::get_last_parent_shape_infer_expr(const ExpressionPtr& start_expr) {
-    auto last_exp = start_expr;
-    if (last_exp->get_input_count() == 0)
-        return last_exp;
-    auto input = last_exp->get_input_port_connector(0);
-    auto first_parent = input->get_source().get_expr();
-    while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
-        last_exp = first_parent;
-        if (last_exp->get_input_count() == 0)
-            break;
-        input = last_exp->get_input_port_connector(0);
-        first_parent = input->get_source().get_expr();
-        if (!ov::is_type(first_parent->get_node())) {
-            // there are maybe some loopEnd consumers of store as well for loop code gen purpose
-            OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if it doesn't consume a store ops.");
-        }
-    }
-    return last_exp;
-}
-
 LinearIR::LIRShapeInfer::LIRShapeInfer(container& body_exprs, io_container& io_exprs)
     : ShapeInferSnippetsNode(),
       m_exprs{std::make_shared(body_exprs)} {
diff --git a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp
index 72769f210015a4..027a5b9d5a423b 100644
--- a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp
+++ b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp
@@ -13,6 +13,7 @@
 #include "snippets/lowered/pass/normalize_buffer_ids.hpp"
 #include "snippets/pass/tokenization.hpp"
 #include "snippets/itt.hpp"
+#include "snippets/utils.hpp"
 
 namespace ov {
 namespace snippets {
@@ -46,8 +47,9 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
         }
     }
     // Propagate to down: in Load. Buffer can have several Load
-    auto last_shape_infer = ov::snippets::lowered::LinearIR::get_last_child_shape_infer_expr(buffer_expr);
-    const auto& buffer_out = last_shape_infer->get_output_port_connector(0);
+    const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(buffer_expr);
+    const auto& target_expr = shape_infer_seq.empty() ? buffer_expr : shape_infer_seq.back();
+    const auto& buffer_out = target_expr->get_output_port_connector(0);
     for (const auto& child_expr_input : buffer_out->get_consumers()) {
         const auto& child_expr = child_expr_input.get_expr();
         const auto port = child_expr_input.get_index();
diff --git a/src/common/snippets/src/lowered/pass/assign_registers.cpp b/src/common/snippets/src/lowered/pass/assign_registers.cpp
index 8792e3a7fe989f..ae5e48736bb27f 100644
--- a/src/common/snippets/src/lowered/pass/assign_registers.cpp
+++ b/src/common/snippets/src/lowered/pass/assign_registers.cpp
@@ -5,9 +5,9 @@
 #include "snippets/lowered/pass/assign_registers.hpp"
 
 #include "snippets/lowered/linear_ir.hpp"
-#include "snippets/op/subgraph.hpp"
 #include "snippets/snippets_isa.hpp"
 #include "snippets/itt.hpp"
+#include "snippets/utils.hpp"
 
 // This header is needed to avoid MSVC warning "C2039: 'inserter': is not a member of 'std'"
 #include
@@ -82,20 +82,16 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
                 manually_assigned_gprs[out_connector] = io_expr->get_index();
                 // TODO [96434]: Support shape infer ops in arbitrary place in pipeline, not just after inputs
                 // shape infer ops sequence after input
-                auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(io_expr);
-                if (!shape_infer_consumers.empty()) {
-                    for (const auto& child_shape_infer_expr : shape_infer_consumers) {
-                        manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
-                    }
+                const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(io_expr);
+                for (const auto& child_shape_infer_expr : shape_infer_consumers) {
+                    manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_expr->get_index();
                 }
             } else if (io_expr->get_type() == IOExpression::io_type::OUTPUT) {
                 manually_assigned_gprs[expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
                 // shape infer ops sequence before result
-                auto shape_infer_sources = LinearIR::get_parent_shape_infer_expr_seq(io_expr);
-                if (!shape_infer_sources.empty()) {
-                    for (const auto& parent_shape_infer_expr : shape_infer_sources) {
-                        manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
-                    }
+                const auto& shape_infer_sources = utils::get_first_parent_shape_infer_expr_seq(io_expr);
+                for (const auto& parent_shape_infer_expr : shape_infer_sources) {
+                    manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = num_parameters + io_expr->get_index();
                 }
             } else {
                 OPENVINO_THROW("Unsupported io_type detected");
@@ -108,13 +104,11 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
                     static_cast(num_results + num_parameters + buffer_id);
                 // shape infer ops in the middle of subgraph. IntermediateMemoryBuffer is inserted before reshape as new loop should start.
                 // child shape info ops share the same memory as IntermediateMemoryBuffer.
-                auto shape_infer_consumers = LinearIR::get_child_shape_infer_expr_seq(expr);
-                if (!shape_infer_consumers.empty()) {
-                    for (const auto& child_shape_infer_expr : shape_infer_consumers) {
-                        manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
-                            manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] =
-                                static_cast(num_results + num_parameters + buffer_id);
-                    }
+                const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(expr);
+                for (const auto& child_shape_infer_expr : shape_infer_consumers) {
+                    manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] =
+                        manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] =
+                            static_cast(num_results + num_parameters + buffer_id);
                 }
             }
             manually_assigned_gprs[expr->get_output_port_connector(0)] =
diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp
index ea3139c84bbd04..8e3bafd5b49fb3 100644
--- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp
+++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp
@@ -148,19 +148,17 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
             const auto port_idx = entry_port->get_index();
             const auto node = expr->get_node();
             auto parent_expr_output = expr->get_input_port_connector(port_idx)->get_source();
-
-            const auto& first_parent_expr = parent_expr_output.get_expr();
+            auto parent_expr = parent_expr_output.get_expr();
             bool has_shape_infer_parent = false;
             auto top_shape_infer_expr = expr;
             // parent before shape infer ops is used to determine if buffer needed according loopInfo
-            auto shape_infer_parents = LinearIR::get_parent_shape_infer_expr_seq(first_parent_expr);
+            const auto& shape_infer_parents = utils::get_first_parent_shape_infer_expr_seq(parent_expr);
             if (!shape_infer_parents.empty()) {
                 parent_expr_output = shape_infer_parents.back()->get_input_port_connector(0)->get_source();
                 has_shape_infer_parent = true;
                 top_shape_infer_expr = shape_infer_parents.back();
+                parent_expr = parent_expr_output.get_expr();
             }
-
-            const auto& parent_expr = parent_expr_output.get_expr();
             const auto& parent_port = parent_expr_output.get_index();
             const auto& parent = parent_expr->get_node();
             if (ov::is_type(parent) ||
diff --git a/src/common/snippets/src/lowered/pass/insert_load_store.cpp b/src/common/snippets/src/lowered/pass/insert_load_store.cpp
index b36faa550d40d5..c8466d85747ce6 100644
--- a/src/common/snippets/src/lowered/pass/insert_load_store.cpp
+++ b/src/common/snippets/src/lowered/pass/insert_load_store.cpp
@@ -35,8 +35,8 @@ size_t InsertLoadStore::get_count(const ExpressionPort& port) const {
 }
 
 bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
-    std::shared_ptr data_expr = *data_expr_it;
-    data_expr = LinearIR::get_last_child_shape_infer_expr(data_expr);
+    const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(*data_expr_it);
+    const std::shared_ptr& data_expr = shape_infer_seq.empty() ? *data_expr_it : shape_infer_seq.back();
     const auto& data_ngraph_output = data_expr->get_node()->output(0);
     bool was_inserted = false;
     const auto& data_out = data_expr->get_output_port_connector(0);
@@ -56,8 +56,8 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
 }
 
 bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) {
-    auto data_expr = *data_expr_it;
-    data_expr = LinearIR::get_last_parent_shape_infer_expr(data_expr);
+    const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(*data_expr_it);
+    const auto& data_expr = shape_infer_seq.empty() ? *data_expr_it : shape_infer_seq.back();
     const auto& parent_output = data_expr->get_input_port_connector(0)->get_source();
     const auto& parent_expr = parent_output.get_expr();
diff --git a/src/common/snippets/src/lowered/pass/validate.cpp b/src/common/snippets/src/lowered/pass/validate.cpp
index 1da16eb2feff60..c890811efb4063 100644
--- a/src/common/snippets/src/lowered/pass/validate.cpp
+++ b/src/common/snippets/src/lowered/pass/validate.cpp
@@ -32,7 +32,8 @@ void validate_ports(const ExpressionPtr& expr) {
 void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
     OPENVINO_ASSERT(ov::is_type(expr->get_node()),
                     "Parameter validation expects Parameter op");
-    auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
+    const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(expr);
+    const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
     auto consumer_inputs = expr_val->get_output_port_connector(0)->get_consumers();
     std::set> layouts;
     for (const auto& consumer_input : consumer_inputs) {
@@ -51,7 +52,8 @@ void validate_parameter(const ExpressionPtr& expr, const LinearIR& linear_ir) {
 void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) {
     OPENVINO_ASSERT(ov::is_type(expr->get_node()),
                     "Result validation expects Result op");
-    auto expr_val = LinearIR::get_last_parent_shape_infer_expr(expr);
+    const auto& shape_infer_seq = utils::get_first_parent_shape_infer_expr_seq(expr);
+    const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
     const auto source = expr_val->get_input_port_connector(0)->get_source();
     const auto ma = ov::as_type_ptr(source.get_expr()->get_node());
     OPENVINO_ASSERT(ma && ma->is_memory_access_output_port(source.get_index()),
@@ -66,7 +68,8 @@ void validate_result(const ExpressionPtr& expr, const LinearIR& linear_ir) {
     const auto ma = ov::as_type_ptr(source.get_expr()->get_node());
     OPENVINO_ASSERT(ma && ma->is_memory_access_input_port(source.get_index()),
                     "Buffer expects MemoryAccess parent");
-    auto expr_val = LinearIR::get_last_child_shape_infer_expr(expr);
+    const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(expr);
+    const auto& expr_val = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
     const auto& out = expr_val->get_output_port_connector(0);
     const auto consumers = out->get_consumers();
     for (const auto& consumer_input : consumers) {
diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp
index b0eac09e6a6ffe..6ed24fa70f11ee 100644
--- a/src/common/snippets/src/op/subgraph.cpp
+++ b/src/common/snippets/src/op/subgraph.cpp
@@ -93,7 +93,7 @@ auto Subgraph::get_last_child_shape_infer_op(const std::shared_ptr& op
         return last_op;
     auto consumers = last_op->get_output_target_inputs(0);
     auto first_child = consumers.begin()->get_node()->shared_from_this();
-    while (op::Subgraph::is_shape_infer_op(first_child)) {
+    while (is_shape_infer_op(first_child)) {
         OPENVINO_ASSERT(consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
         last_op = first_child;
         if (last_op->get_output_size() == 0)
@@ -109,7 +109,7 @@ auto Subgraph::get_last_parent_shape_infer_op(const std::shared_ptr& o
     if (last_op->get_input_size() == 0)
         return last_op;
     auto first_parent = last_op->get_input_node_shared_ptr(0);
-    while (op::Subgraph::is_shape_infer_op(first_parent)) {
+    while (is_shape_infer_op(first_parent)) {
         last_op = first_parent;
         if (last_op->get_input_size() == 0)
             break;
diff --git a/src/common/snippets/src/pass/gn_decomposition.cpp b/src/common/snippets/src/pass/gn_decomposition.cpp
index 4b72ef0e5dea6f..04d3fdb0ac5971 100644
--- a/src/common/snippets/src/pass/gn_decomposition.cpp
+++ b/src/common/snippets/src/pass/gn_decomposition.cpp
@@ -5,10 +5,8 @@
 #include "snippets/pass/gn_decomposition.hpp"
 
 #include "openvino/op/group_normalization.hpp"
-#include "snippets/op/reduce.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "snippets/itt.hpp"
-#include "snippets/lowered/port_descriptor.hpp"
 #include "snippets/snippets_isa.hpp"
 #include "openvino/core/rt_info.hpp"
diff --git a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
index 371de613305f37..49cc1a379c8b18 100644
--- a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
+++ b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
@@ -248,18 +248,18 @@ Result ReduceShapeInfer::infer(const std::vector& input_shapes) {
 ReshapeShapeInfer::ReshapeShapeInfer(const std::shared_ptr& n) {
     const auto& reshape = as_type_ptr(n);
     OPENVINO_ASSERT(reshape, "Invalid node passed to ReshapeShapeInfer.");
-    target_shape = reshape->get_target_shape();
+    const auto& partial_shape = reshape->get_target_shape();
+    OPENVINO_ASSERT(partial_shape.is_static(), "target_shape of reshape op should be static in ReshapeShapeInfer");
+    target_shape = partial_shape.get_shape();
+    target_shape_volume = utils::get_shape_size(target_shape);
 }
 
 Result ReshapeShapeInfer::infer(const std::vector& input_shapes) {
     OPENVINO_ASSERT(input_shapes.size() == 1, "Invalid number of shapes is passed in ReshapeShapeInfer");
-    OPENVINO_ASSERT(target_shape.is_static(), "target_shape should be static in ReshapeShapeInfer");
-    VectorDims result_shape = target_shape.get_shape();
-    const auto input_elems = utils::get_shape_size(input_shapes[0].get());
-    const auto output_elems = utils::get_shape_size(result_shape);
-    OPENVINO_ASSERT(input_elems == output_elems, "Tensor volume should be the same after reshape in ReshapeShapeInfer");
+    const auto input_shape_volume = utils::get_shape_size(input_shapes[0].get());
+    OPENVINO_ASSERT(input_shape_volume == target_shape_volume, "Tensor volume should be the same after reshape in ReshapeShapeInfer");
 
-    return {{result_shape}, ShapeInferStatus::success};
+    return {{target_shape}, ShapeInferStatus::success};
 }
 
 } // namespace snippets
diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp
index d3106179c7f9ab..45a85d04fba762 100644
--- a/src/common/snippets/src/utils.cpp
+++ b/src/common/snippets/src/utils.cpp
@@ -6,6 +6,7 @@
 #include "snippets/pass/fq_decomposition.hpp"
 
 #include "openvino/core/rt_info.hpp"
+#include "snippets/op/subgraph.hpp"
 
 namespace ov {
@@ -165,6 +166,60 @@ VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_po
     return get_preordered_vdims(expr_port.get_descriptor_ptr()->get_shape(), expr_port.get_descriptor_ptr()->get_layout());
 }
 
+std::vector get_first_child_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr) {
+    auto get_first_shape_infer_expr = [](const std::set& consumers) -> lowered::ExpressionPtr {
+        for (auto it = consumers.begin(); it != consumers.end(); ++it) {
+            auto expr = it->get_expr();
+            if (op::Subgraph::is_shape_infer_op(expr->get_node())) {
+                return expr;
+            }
+        }
+        return nullptr;
+    };
+    std::vector shape_infer_exprs;
+    if (op::Subgraph::is_shape_infer_op(start_expr->get_node())) {
+        OPENVINO_ASSERT(start_expr->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
+        shape_infer_exprs.push_back(start_expr);
+    }
+    if (start_expr->get_output_count() == 0)
+        return shape_infer_exprs;
+    auto output_consumers = start_expr->get_output_port_connector(0)->get_consumers();
+    while (auto shape_infer_child = get_first_shape_infer_expr(output_consumers)) {
+        OPENVINO_ASSERT(output_consumers.size() == 1, "Shape infer ops are supposed to be the only consumer.");
+        shape_infer_exprs.push_back(shape_infer_child);
+        if (shape_infer_child->get_output_count() == 0)
+            break;
+        output_consumers = shape_infer_child->get_output_port_connector(0)->get_consumers();
+    }
+    return shape_infer_exprs;
+}
+
+std::vector get_first_parent_shape_infer_expr_seq(const lowered::ExpressionPtr& start_expr) {
+    std::vector shape_infer_exprs;
+    auto current_exp = start_expr;
+    if (op::Subgraph::is_shape_infer_op(current_exp->get_node())) {
+        OPENVINO_ASSERT(current_exp->get_input_port_connector(0)->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer.");
+        shape_infer_exprs.push_back(current_exp);
+    }
+    if (current_exp->get_input_count() == 0)
+        return shape_infer_exprs;
+    auto input = current_exp->get_input_port_connector(0);
+    auto first_parent = input->get_source().get_expr();
+    while (op::Subgraph::is_shape_infer_op(first_parent->get_node())) {
+        shape_infer_exprs.push_back(first_parent);
+        current_exp = first_parent;
+        if (current_exp->get_input_count() == 0)
+            break;
+        input = current_exp->get_input_port_connector(0);
+        first_parent = input->get_source().get_expr();
+        if (!ov::is_type(first_parent->get_node())) {
+            // there may also be some LoopEnd consumers of the Store for loop code generation purposes
+            OPENVINO_ASSERT(input->get_consumers().size() == 1, "Shape infer ops are supposed to be the only consumer if the parent is not a Store op.");
+        }
+    }
+    return shape_infer_exprs;
+}
+
 } // namespace utils
 } // namespace snippets
 } // namespace ov
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp
index 9c7f5cc40c21b1..355c29b54513a3 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp
@@ -3,6 +3,7 @@
 //
 
 #include "jit_kernel_emitter.hpp"
+#include "snippets/utils.hpp"
 
 using namespace Xbyak;
@@ -201,7 +202,8 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g
         switch (expr->get_type()) {
            case snippets::lowered::IOExpression::io_type::INPUT: {
                 // input->shape changing ops->load
-                auto mem_desc_expr = ov::snippets::lowered::LinearIR::get_last_child_shape_infer_expr(expr);
+                const auto& shape_infer_seq = ov::snippets::utils::get_first_child_shape_infer_expr_seq(expr);
+                const auto& mem_desc_expr = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
                 auto consumer_inputs = mem_desc_expr->get_output_port_connector(0)->get_consumers();
                 for (const auto& child_input : consumer_inputs) {
                     const auto ma = ov::as_type_ptr(child_input.get_expr()->get_node());
@@ -211,12 +213,12 @@ jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_g
                     }
                 }
                 etype = mem_desc_expr->get_node()->get_output_element_type(0);
-                break;
                 break;
             }
             case snippets::lowered::IOExpression::io_type::OUTPUT: {
                 // store->shape changing ops->result
-                auto mem_desc_expr = ov::snippets::lowered::LinearIR::get_last_parent_shape_infer_expr(expr);
+                const auto& shape_infer_seq = ov::snippets::utils::get_first_parent_shape_infer_expr_seq(expr);
+                const auto& mem_desc_expr = shape_infer_seq.empty() ? expr : shape_infer_seq.back();
                 desc = mem_desc_expr->get_input_port_connector(0)->get_source().get_descriptor_ptr();
                 etype = mem_desc_expr->get_node()->get_input_element_type(0);
                 break;
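
Every updated call site above follows the same idiom: query the new utils helper for the first child (or parent) shape infer sequence and fall back to the original expression when that sequence is empty. A minimal sketch of the idiom is below; the wrapper name and surrounding function are illustrative only and not part of the patch, while utils::get_first_child_shape_infer_expr_seq is the helper this change introduces.

    // Sketch only (not part of the patch): returns the last expression of the first
    // child shape-infer sequence, or `expr` itself when no shape infer op follows it.
    #include "snippets/utils.hpp"

    using ov::snippets::lowered::ExpressionPtr;

    static ExpressionPtr last_child_shape_infer_or_self(const ExpressionPtr& expr) {
        const auto& seq = ov::snippets::utils::get_first_child_shape_infer_expr_seq(expr);
        return seq.empty() ? expr : seq.back();
    }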