From 9cf1017e7120f98c1fb91b12f171788a64c0b618 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 15:25:45 +0100 Subject: [PATCH] Use predicates for type & shape checks that don't depend on other nodes in GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 93 ++++++++----------- 1 file changed, 39 insertions(+), 54 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index 387b6c450e8d5b..ced652fc5edb8c 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -25,35 +25,63 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { MATCHER_SCOPE(GroupNormalizationFusion); - auto input_m = ov::pass::pattern::any_input(); + auto has_real_not_quantized_type = [](const ov::Output& output) -> bool { + const auto& T = output.get_element_type(); + return (T.is_real() && (!T.is_quantized())); + }; + + auto has_integral_type = [](const ov::Output& output) -> bool { + const auto& T = output.get_element_type(); + return (T.is_integral()); + }; - auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type(); - auto pre_mvn_reshape_m = ov::pass::pattern::wrap_type({input_m, pre_mvn_shape_const_m}); + auto has_at_least_2d_shape = [](const ov::Output& output) -> bool { + const auto& output_ps = output.get_partial_shape(); + return (output_ps.rank().is_static()) && (output_ps.rank().get_length() >= 2); + }; - auto axes_const_m = ov::pass::pattern::wrap_type(); + auto input_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of( + {has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)})); + + auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type(ov::pass::pattern::all_of( + {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); + auto pre_mvn_reshape_m = ov::pass::pattern::wrap_type( + {input_m, pre_mvn_shape_const_m}, + ov::pass::pattern::all_of( + {has_real_not_quantized_type, ov::pass::pattern::rank_equals(3), ov::pass::pattern::has_static_dim(1)})); + + auto axes_const_m = ov::pass::pattern::wrap_type(ov::pass::pattern::all_of( + {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); auto mvn_m = ov::pass::pattern::wrap_type({pre_mvn_reshape_m, axes_const_m}); - auto instance_norm_gamma_m = ov::pass::pattern::any_input(); + auto instance_norm_gamma_m = ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto instance_norm_gamma_multiply_m = ov::pass::pattern::wrap_type({mvn_m, instance_norm_gamma_m}); auto instance_norm_opt_gamma_m = std::make_shared(ov::OutputVector{mvn_m, instance_norm_gamma_multiply_m}); - auto instance_norm_beta_m = ov::pass::pattern::any_input(); + auto instance_norm_beta_m = ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto instance_norm_beta_add_m = ov::pass::pattern::wrap_type({instance_norm_opt_gamma_m, instance_norm_beta_m}); auto instance_norm_opt_gamma_opt_beta_m = std::make_shared( ov::OutputVector{instance_norm_opt_gamma_m, instance_norm_beta_add_m}); - auto post_instance_norm_shape_m = ov::pass::pattern::any_input(); + auto post_instance_norm_shape_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of( + {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); auto post_instance_norm_reshape_m = ov::pass::pattern::wrap_type( - {instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m}); + {instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m}, + ov::pass::pattern::all_of( + {has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)})); - auto group_norm_gamma_m = ov::pass::pattern::any_input(); + auto group_norm_gamma_m = ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto group_norm_gamma_multiply_m = ov::pass::pattern::wrap_type({post_instance_norm_reshape_m, group_norm_gamma_m}); - auto group_norm_beta_m = ov::pass::pattern::any_input(); + auto group_norm_beta_m = ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto group_norm_beta_add_m = ov::pass::pattern::wrap_type({group_norm_gamma_multiply_m, group_norm_beta_m}); @@ -65,30 +93,9 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& T = input.get_element_type(); - // this pattern supports only real and not quantized data types - if ((!T.is_real()) || (T.is_quantized())) - return false; - - // expecting at least 2D tensor as pattern input: - // (batch_size, num_channels, ...) - if (input_ps.size() < 2) - return false; - // channel dimension has to be static, all other dimensions in input can be dynamic - if (input_ps[1].is_dynamic()) - return false; - const auto& pre_mvn_reshape_out = pattern_map.at(pre_mvn_reshape_m); const auto& pre_mvn_reshape_out_ps = pre_mvn_reshape_out.get_partial_shape(); - // expecting 3D static tensor as pre-MVN reshape input: - // (batch_size, num_groups, -1) - if (pre_mvn_reshape_out_ps.size() != 3) - return false; - - // channel dimension has to be static, all other dimensions in pre-MVN reshape can be dynamic - if (pre_mvn_reshape_out_ps[1].is_dynamic()) - return false; - const auto& num_channels = input_ps[1].get_max_length(); const auto& num_groups = pre_mvn_reshape_out_ps[1].get_max_length(); @@ -97,11 +104,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { return false; auto channels_to_groups_ratio = num_channels / num_groups; - // MVN input has to have at least two dimensions: - // (batch_size, num_groups, ...) - if (pre_mvn_reshape_out_ps.size() < 2) - return false; - // first dimension of MVN input (batch_size) has to be the same // as in pattern input if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length()) @@ -110,8 +112,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& post_instance_norm_reshape_out = pattern_map.at(post_instance_norm_reshape_m); const auto& post_instance_norm_reshape_out_ps = post_instance_norm_reshape_out.get_partial_shape(); - // post instance norm shape has to be same as in pattern input: - // (batch_size, num_channels, height, width) + // post instance norm shape has to be same as in pattern input if (post_instance_norm_reshape_out_ps != input_ps) return false; @@ -123,10 +124,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (group_norm_gamma.get_element_type() != T) return false; - // group_norm_gamma has to be static - if (group_norm_gamma_ps.is_dynamic()) - return false; - // number of elements in group_norm_gamma must be equal to // number of channels if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels) @@ -140,10 +137,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (group_norm_beta.get_element_type() != T) return false; - // group_norm_beta has to be static - if (group_norm_beta_ps.is_dynamic()) - return false; - // number of elements in group_norm_beta must be equal to // number of channels if (ov::shape_size(group_norm_beta.get_shape()) != num_channels) @@ -175,10 +168,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (instance_norm_beta.get_element_type() != T) return false; - // instance_norm_beta has to be static - if (instance_norm_beta_ps.is_dynamic()) - return false; - // number of elements in instance_norm_beta must be equal to // number of groups if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups) @@ -212,10 +201,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (instance_norm_gamma.get_element_type() != T) return false; - // instance_norm_gamma has to be static - if (instance_norm_gamma_ps.is_dynamic()) - return false; - // number of elements in instance_norm_gamma must be equal to // number of groups if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)