diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 53bdac693788..7c9eb28b83b7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -34,6 +34,11 @@ static llvm::cl::opt clEnableTransposePropagation(
     llvm::cl::desc(
         "Enables propagation of transpose ops to improve fusion chances."),
     llvm::cl::init(true));
+// Sets PropagateLinalgTransposePassOptions::enableAttentionVTranspose.
+static llvm::cl::opt<bool> clEnableAttentionVTranspose(
+    "iree-global-opt-enable-attention-v-transpose",
+    llvm::cl::desc("Enables transposition of v operand of attention ops."),
+    llvm::cl::init(true));
 
 // TODO(hanchung): Remove the flag. We don't want to do early materialization by
 // default. Because it won't work for heterogeneous computing. This is not the
@@ -157,8 +162,11 @@ void buildGlobalOptimizationPassPipeline(
       .addPredicatedPass(
           clEnableTransposePropagation,
           [&]() {
-            return createPropagateLinalgTransposePass(
-                transformOptions.options.aggressiveTransposePropagation);
+            PropagateLinalgTransposePassOptions options;
+            options.enableAggressivePropagation =
+                transformOptions.options.aggressiveTransposePropagation;
+            options.enableAttentionVTranspose = clEnableAttentionVTranspose;
+            return createPropagateLinalgTransposePass(options);
           })
       .addPass(IREE::Flow::createCanonicalizerPass)
       .addPass(mlir::createCSEPass);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index 1d12334eeb42..4fcbbfbb5c89 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -115,6 +115,8 @@ def PropagateLinalgTransposePass :
            "Flag used for lit-testing sinking patterns only. Not for general usage">,
     Option<"testBubblingOnly", "test-bubbling-only", "bool",
            /*default=*/"false", "Flag used for lit-testing bubbling patterns only. Not for general usage">,
+    Option<"enableAttentionVTranspose", "enable-attention-v-transpose", "bool",
+           /*default=*/"true", "Enable transposition of attention v operand">,
   ];
 }
 
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index 1e58f0e76211..e2104184f7b5 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -1092,15 +1092,20 @@ void PropagateLinalgTransposePass::runOnOperation() {
     linalg::populateFoldReshapeOpsByExpansionPatterns(bubblingPatterns,
                                                       reshapePropagationFn);
     linalg::FillOp::getCanonicalizationPatterns(bubblingPatterns, context);
-    linalg::ControlFusionFn bubbleTransposeControlFn =
-        [](OpOperand *fusedOperand) {
-          Operation *producer = fusedOperand->get().getDefiningOp();
-          Operation *consumer = fusedOperand->getOwner();
 
-          return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
-        };
-    IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
-        bubblingPatterns, bubbleTransposeControlFn);
+    // Register the LinalgExt transpose-bubbling patterns only when enabled.
+    if (enableAttentionVTranspose) {
+      linalg::ControlFusionFn bubbleTransposeControlFn =
+          [](OpOperand *fusedOperand) {
+            Operation *producer = fusedOperand->get().getDefiningOp();
+            Operation *consumer = fusedOperand->getOwner();
+
+            return IREE::Flow::isNonNullAndOutsideDispatch(
+                {producer, consumer});
+          };
+      IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
+          bubblingPatterns, bubbleTransposeControlFn);
+    }
     bubblingPatterns.insert(
         context, enableAggressivePropagation);
     bubblingPatterns.insert(context);