Skip to content

Commit

Permalink
Enable GroupNormalization fusion pass in GPU plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
jhajducz committed Jan 12, 2025
1 parent 3ec3026 commit ba87e35
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
#include "transformations/common_optimizations/broadcast_transition.hpp"
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/group_normalization_fusion.hpp"
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
#include "transformations/common_optimizations/lstm_cell_fusion.hpp"
#include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp"
Expand Down Expand Up @@ -290,6 +291,13 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
auto pass_config = manager.get_pass_config();
manager.set_per_pass_validation(false);

// fuse following ops into GroupNormalization:
// group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta
// note that instance norm related parameters are optional:
// - instance_norm_gamma is assumed to be filled with ones if not present in the graph
// - instance_norm_beta is assumed to be filled with zeros if not present in the graph
manager.register_pass<ov::pass::GroupNormalizationFusion>();

// Temporary solution, global rt info cleanup is needed
for (auto& node : func->get_ops()) {
ov::enable_constant_folding(node);
Expand Down

0 comments on commit ba87e35

Please sign in to comment.