Skip to content

Commit

Permalink
Use predicates for type & shape checks that don't depend on other nod…
Browse files Browse the repository at this point in the history
…es in GroupNormalizationFusion pass
  • Loading branch information
jhajducz committed Jan 16, 2025
1 parent 45160a8 commit 9cf1017
Showing 1 changed file with 39 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,63 @@
ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
MATCHER_SCOPE(GroupNormalizationFusion);

auto input_m = ov::pass::pattern::any_input();
auto has_real_not_quantized_type = [](const ov::Output<ov::Node>& output) -> bool {
const auto& T = output.get_element_type();
return (T.is_real() && (!T.is_quantized()));
};

auto has_integral_type = [](const ov::Output<ov::Node>& output) -> bool {
const auto& T = output.get_element_type();
return (T.is_integral());
};

auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto pre_mvn_reshape_m = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({input_m, pre_mvn_shape_const_m});
auto has_at_least_2d_shape = [](const ov::Output<ov::Node>& output) -> bool {
const auto& output_ps = output.get_partial_shape();
return (output_ps.rank().is_static()) && (output_ps.rank().get_length() >= 2);
};

auto axes_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto input_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of(
{has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)}));

auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(ov::pass::pattern::all_of(
{has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)}));
auto pre_mvn_reshape_m = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
{input_m, pre_mvn_shape_const_m},
ov::pass::pattern::all_of(
{has_real_not_quantized_type, ov::pass::pattern::rank_equals(3), ov::pass::pattern::has_static_dim(1)}));

auto axes_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(ov::pass::pattern::all_of(
{has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)}));
auto mvn_m = ov::pass::pattern::wrap_type<ov::op::v6::MVN>({pre_mvn_reshape_m, axes_const_m});

auto instance_norm_gamma_m = ov::pass::pattern::any_input();
auto instance_norm_gamma_m = ov::pass::pattern::any_input(
ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()}));
auto instance_norm_gamma_multiply_m =
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({mvn_m, instance_norm_gamma_m});
auto instance_norm_opt_gamma_m =
std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{mvn_m, instance_norm_gamma_multiply_m});

auto instance_norm_beta_m = ov::pass::pattern::any_input();
auto instance_norm_beta_m = ov::pass::pattern::any_input(
ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()}));
auto instance_norm_beta_add_m =
ov::pass::pattern::wrap_type<ov::op::v1::Add>({instance_norm_opt_gamma_m, instance_norm_beta_m});
auto instance_norm_opt_gamma_opt_beta_m = std::make_shared<ov::pass::pattern::op::Or>(
ov::OutputVector{instance_norm_opt_gamma_m, instance_norm_beta_add_m});

auto post_instance_norm_shape_m = ov::pass::pattern::any_input();
auto post_instance_norm_shape_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of(
{has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)}));
auto post_instance_norm_reshape_m = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
{instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m});
{instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m},
ov::pass::pattern::all_of(
{has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)}));

auto group_norm_gamma_m = ov::pass::pattern::any_input();
auto group_norm_gamma_m = ov::pass::pattern::any_input(
ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()}));
auto group_norm_gamma_multiply_m =
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({post_instance_norm_reshape_m, group_norm_gamma_m});

auto group_norm_beta_m = ov::pass::pattern::any_input();
auto group_norm_beta_m = ov::pass::pattern::any_input(
ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()}));
auto group_norm_beta_add_m =
ov::pass::pattern::wrap_type<ov::op::v1::Add>({group_norm_gamma_multiply_m, group_norm_beta_m});

Expand All @@ -65,30 +93,9 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {

const auto& T = input.get_element_type();

// this pattern supports only real and not quantized data types
if ((!T.is_real()) || (T.is_quantized()))
return false;

// expecting at least 2D tensor as pattern input:
// (batch_size, num_channels, ...)
if (input_ps.size() < 2)
return false;
// channel dimension has to be static, all other dimensions in input can be dynamic
if (input_ps[1].is_dynamic())
return false;

const auto& pre_mvn_reshape_out = pattern_map.at(pre_mvn_reshape_m);
const auto& pre_mvn_reshape_out_ps = pre_mvn_reshape_out.get_partial_shape();

// expecting 3D static tensor as pre-MVN reshape input:
// (batch_size, num_groups, -1)
if (pre_mvn_reshape_out_ps.size() != 3)
return false;

// channel dimension has to be static, all other dimensions in pre-MVN reshape can be dynamic
if (pre_mvn_reshape_out_ps[1].is_dynamic())
return false;

const auto& num_channels = input_ps[1].get_max_length();
const auto& num_groups = pre_mvn_reshape_out_ps[1].get_max_length();

Expand All @@ -97,11 +104,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
return false;
auto channels_to_groups_ratio = num_channels / num_groups;

// MVN input has to have at least two dimensions:
// (batch_size, num_groups, ...)
if (pre_mvn_reshape_out_ps.size() < 2)
return false;

// first dimension of MVN input (batch_size) has to be the same
// as in pattern input
if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length())
Expand All @@ -110,8 +112,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
const auto& post_instance_norm_reshape_out = pattern_map.at(post_instance_norm_reshape_m);
const auto& post_instance_norm_reshape_out_ps = post_instance_norm_reshape_out.get_partial_shape();

// post instance norm shape has to be same as in pattern input:
// (batch_size, num_channels, height, width)
// post instance norm shape has to be same as in pattern input
if (post_instance_norm_reshape_out_ps != input_ps)
return false;

Expand All @@ -123,10 +124,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
if (group_norm_gamma.get_element_type() != T)
return false;

// group_norm_gamma has to be static
if (group_norm_gamma_ps.is_dynamic())
return false;

// number of elements in group_norm_gamma must be equal to
// number of channels
if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels)
Expand All @@ -140,10 +137,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
if (group_norm_beta.get_element_type() != T)
return false;

// group_norm_beta has to be static
if (group_norm_beta_ps.is_dynamic())
return false;

// number of elements in group_norm_beta must be equal to
// number of channels
if (ov::shape_size(group_norm_beta.get_shape()) != num_channels)
Expand Down Expand Up @@ -175,10 +168,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
if (instance_norm_beta.get_element_type() != T)
return false;

// instance_norm_beta has to be static
if (instance_norm_beta_ps.is_dynamic())
return false;

// number of elements in instance_norm_beta must be equal to
// number of groups
if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups)
Expand Down Expand Up @@ -212,10 +201,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
if (instance_norm_gamma.get_element_type() != T)
return false;

// instance_norm_gamma has to be static
if (instance_norm_gamma_ps.is_dynamic())
return false;

// number of elements in instance_norm_gamma must be equal to
// number of groups
if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)
Expand Down

0 comments on commit 9cf1017

Please sign in to comment.