Skip to content

Commit

Permalink
tokenize only single groupNorm to subgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 2, 2024
1 parent d30d70b commit d24e71d
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
 * @interface TokenizeGroupNormSnippets
 * @brief Matcher pass that wraps a single ov::op::v12::GroupNormalization node
 *        into a snippets Subgraph node.
 * @ingroup snippets
 */
class TokenizeGroupNormSnippets : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("TokenizeGroupNormSnippets", "0");

    /// Registers the GroupNormalization matcher together with its tokenization callback.
    TokenizeGroupNormSnippets();
};

} // namespace pass
} // namespace snippets
} // namespace ov
5 changes: 4 additions & 1 deletion src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "snippets/pass/set_softmax_ports.hpp"
#include "snippets/pass/canonicalization.hpp"
#include "snippets/pass/align_element_types.hpp"
#include "snippets/pass/group_normalization_decomposition.hpp"
#include "snippets/lowered/pass/validate_shapes.hpp"

#include "snippets/utils.hpp"
Expand Down Expand Up @@ -75,7 +76,8 @@ auto Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bo
ov::is_type<ov::op::v8::Softmax>(op) ||
ov::is_type<ov::op::v0::MatMul>(op) ||
ov::is_type<ov::op::v1::Broadcast>(op) || // Broadcast is domain sensetive op because the output shape depends on
ov::is_type<ov::op::v3::Broadcast>(op); // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern
ov::is_type<ov::op::v3::Broadcast>(op) || // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern
ov::is_type<ov::op::v12::GroupNormalization>(op);
}

void Subgraph::init_config() {
Expand Down Expand Up @@ -395,6 +397,7 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input
manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
manager.register_pass<snippets::pass::TransposeDecomposition>();
manager.register_pass<snippets::pass::SetSoftmaxPorts>();
manager.register_pass<snippets::pass::GroupNormalizationDecomposition>();
}
manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
Expand Down
34 changes: 34 additions & 0 deletions src/common/snippets/src/pass/group_norm_tokenization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/pass/group_norm_tokenization.hpp"

#include "snippets/itt.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

// Registers a matcher that tokenizes a standalone GroupNormalization (opset12) node:
// the matched node is wrapped into a snippets Subgraph which then replaces it in the model.
//
// The callback returns true when the node was replaced by a Subgraph, and false when the
// node is left untouched (unexpected root type or dynamic shapes).
ov::snippets::pass::TokenizeGroupNormSnippets::TokenizeGroupNormSnippets() {
    MATCHER_SCOPE(TokenizeGroupNormSnippets);

    auto group_norm_pattern = ov::pass::pattern::wrap_type<ov::op::v12::GroupNormalization>();

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::pass::TokenizeGroupNormSnippets")
        const auto group_norm_node = ov::as_type_ptr<ov::op::v12::GroupNormalization>(m.get_match_root());
        // Do not tokenize what cannot be lowered afterwards: the GroupNormalization
        // decomposition executed inside the subgraph queries the static input shape
        // (get_input_shape), which is not available for dynamic nodes. The null check
        // protects the dereferences below should the cast ever fail.
        if (!group_norm_node || group_norm_node->is_dynamic())
            return false;

        auto subgraph = op::Subgraph::wrap_node_as_subgraph(group_norm_node);
        // Keep the name of the original layer in the runtime info of the new Subgraph node.
        subgraph->get_rt_info()["originalLayersNames"] = group_norm_node->get_friendly_name();
        ov::replace_node(group_norm_node, subgraph);
        op::update_out_tensor_name(subgraph);

        return true;
    };
    auto m = std::make_shared<ov::pass::pattern::Matcher>(group_norm_pattern, matcher_name);
    register_matcher(m, callback);
}


Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "snippets/itt.hpp"
#include "snippets/lowered/port_descriptor.hpp"
#include "snippets/snippets_isa.hpp"
#include "transformations/utils/utils.hpp"
#include "openvino/core/rt_info.hpp"

namespace ov {
Expand All @@ -32,7 +31,7 @@ GroupNormalizationDecomposition::GroupNormalizationDecomposition() {
const auto bias = group_norm_node->input_value(2);

const auto num_groups = static_cast<size_t>(group_norm_node->get_num_groups());
const auto eps = ov::op::util::cast_eps_to_float(group_norm_node->get_epsilon());
const float eps = static_cast<float>(group_norm_node->get_epsilon());

// reshape [N, C, spatial] to [N, group, C / group, spatial]
const auto orig_shape = group_norm_node->get_input_shape(0);
Expand Down
2 changes: 2 additions & 0 deletions src/common/snippets/src/pass/tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "snippets/pass/common_optimizations.hpp"
#include "snippets/pass/extract_reshapes_from_mha.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/group_norm_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"


Expand Down Expand Up @@ -82,6 +83,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {
manager.register_pass<ExtractReshapesFromMHA>();
manager.register_pass<TokenizeMHASnippets>(m_config);
manager.register_pass<TokenizeSnippets>();
manager.register_pass<TokenizeGroupNormSnippets>();
manager.register_pass<CommonOptimizations>(m_config);
manager.run_passes(m);

Expand Down

0 comments on commit d24e71d

Please sign in to comment.