diff --git a/src/common/snippets/include/snippets/pass/group_norm_tokenization.hpp b/src/common/snippets/include/snippets/pass/group_norm_tokenization.hpp new file mode 100644 index 00000000000000..ba7737a69b0db9 --- /dev/null +++ b/src/common/snippets/include/snippets/pass/group_norm_tokenization.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/matcher.hpp" + +namespace ov { +namespace snippets { +namespace pass { + +/** + * @interface TokenizeGroupNormSnippets + * @brief Tokenize GroupNormalization to a subgraph + * @ingroup snippets + */ +class TokenizeGroupNormSnippets: public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TokenizeGroupNormSnippets", "0"); + TokenizeGroupNormSnippets(); +}; + +} // namespace pass +} // namespace snippets +} // namespace ov \ No newline at end of file diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index b5e5e5c526dd5b..647fcb63f3e822 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -17,6 +17,7 @@ #include "snippets/pass/set_softmax_ports.hpp" #include "snippets/pass/canonicalization.hpp" #include "snippets/pass/align_element_types.hpp" +#include "snippets/pass/group_normalization_decomposition.hpp" #include "snippets/lowered/pass/validate_shapes.hpp" #include "snippets/utils.hpp" @@ -75,7 +76,8 @@ auto Subgraph::is_domain_sensitive_op(const std::shared_ptr& op) -> bo ov::is_type(op) || ov::is_type(op) || ov::is_type(op) || // Broadcast is domain sensetive op because the output shape depends on - ov::is_type(op); // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern + ov::is_type(op) || // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern + ov::is_type(op); } void Subgraph::init_config() { @@ -395,6 +397,7 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); } manager.register_pass(); manager.register_pass(); diff --git a/src/common/snippets/src/pass/group_norm_tokenization.cpp b/src/common/snippets/src/pass/group_norm_tokenization.cpp new file mode 100644 index 00000000000000..4317400c426ba8 --- /dev/null +++ b/src/common/snippets/src/pass/group_norm_tokenization.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/pass/group_norm_tokenization.hpp" + +#include "snippets/itt.hpp" +#include "snippets/op/subgraph.hpp" +#include "snippets/utils.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +ov::snippets::pass::TokenizeGroupNormSnippets::TokenizeGroupNormSnippets() { + MATCHER_SCOPE(TokenizeGroupNormSnippets); + + auto group_norm_pattern = ov::pass::pattern::wrap_type(); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::pass::TokenizeGroupNormSnippets") + auto group_norm_node = ov::as_type_ptr(m.get_match_root()); + + auto subgraph = op::Subgraph::wrap_node_as_subgraph(group_norm_node); + subgraph->get_rt_info()["originalLayersNames"] = group_norm_node->get_friendly_name(); + ov::replace_node(group_norm_node, subgraph); + op::update_out_tensor_name(subgraph); + + return true; + }; + auto m = std::make_shared(group_norm_pattern, matcher_name); + register_matcher(m, callback); +} + + diff --git a/src/common/snippets/src/pass/group_normalization_decomposition.cpp b/src/common/snippets/src/pass/group_normalization_decomposition.cpp index 279b98e29849d3..77e4008425c7a7 100644 --- a/src/common/snippets/src/pass/group_normalization_decomposition.cpp +++ b/src/common/snippets/src/pass/group_normalization_decomposition.cpp @@ -9,7 +9,6 @@ #include "snippets/itt.hpp" #include "snippets/lowered/port_descriptor.hpp" #include "snippets/snippets_isa.hpp" -#include "transformations/utils/utils.hpp" #include "openvino/core/rt_info.hpp" namespace ov { @@ -32,7 +31,7 @@ GroupNormalizationDecomposition::GroupNormalizationDecomposition() { const auto bias = group_norm_node->input_value(2); const auto num_groups = static_cast(group_norm_node->get_num_groups()); - const auto eps = ov::op::util::cast_eps_to_float(group_norm_node->get_epsilon()); + const float eps = static_cast(group_norm_node->get_epsilon()); // reshape [N, C, spatial] to [N, group, C / group, spatial] const auto orig_shape = group_norm_node->get_input_shape(0); diff --git a/src/common/snippets/src/pass/tokenization.cpp b/src/common/snippets/src/pass/tokenization.cpp index 0dc0415356b0b5..c79baea55c4c0e 100644 --- a/src/common/snippets/src/pass/tokenization.cpp +++ b/src/common/snippets/src/pass/tokenization.cpp @@ -9,6 +9,7 @@ #include "snippets/pass/common_optimizations.hpp" #include "snippets/pass/extract_reshapes_from_mha.hpp" #include "snippets/pass/mha_tokenization.hpp" +#include "snippets/pass/group_norm_tokenization.hpp" #include "snippets/pass/collapse_subgraph.hpp" @@ -82,6 +83,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr& m) { manager.register_pass(); manager.register_pass(m_config); manager.register_pass(); + manager.register_pass(); manager.register_pass(m_config); manager.run_passes(m);