Skip to content

Commit

Permalink
Disable Attention V operand transposition.
Browse files Browse the repository at this point in the history
This impacts the ability to horizontally fuse the matmuls that feed
into the `Q-K-V` transpose. The improvements previously seen with this
change might have been due to a reduction in copy overheads, which are
no longer an issue.

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Feb 4, 2025
1 parent eb19497 commit 618dac5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
11 changes: 9 additions & 2 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ static llvm::cl::opt<bool> clEnableTransposePropagation(
llvm::cl::desc(
"Enables propagation of transpose ops to improve fusion chances."),
llvm::cl::init(true));
// Controls whether the transpose-propagation pass is allowed to bubble a
// transpose through the V operand of attention ops. Disabling this keeps the
// matmuls feeding the Q-K-V transpose eligible for horizontal fusion.
static llvm::cl::opt<bool> clEnableAttentionVTranspose(
    "iree-global-opt-enable-attention-v-transpose",
    // Fixed: description previously ended with a stray comma instead of a
    // period, unlike the sibling option descriptions.
    llvm::cl::desc("Enables transposition of v operand of attention ops."),
    llvm::cl::init(true));

// TODO(hanchung): Remove the flag. We don't want to do early materialization by
// default. Because it won't work for heterogeneous computing. This is not the
Expand Down Expand Up @@ -157,8 +161,11 @@ void buildGlobalOptimizationPassPipeline(
.addPredicatedPass(
clEnableTransposePropagation,
[&]() {
return createPropagateLinalgTransposePass(
transformOptions.options.aggressiveTransposePropagation);
PropagateLinalgTransposePassOptions options;
options.enableAggressivePropagation =
transformOptions.options.aggressiveTransposePropagation;
options.enableAttentionVTranspose = clEnableAttentionVTranspose;
return createPropagateLinalgTransposePass(options);
})
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def PropagateLinalgTransposePass :
"Flag used for lit-testing sinking patterns only. Not for general usage">,
Option<"testBubblingOnly", "test-bubbling-only", "bool", /*default=*/"false",
"Flag used for lit-testing bubbling patterns only. Not for general usage">,
Option<"enableAttentionVTranspose", "enable-attention-v-transpose", "bool",
/*default=*/"true", "Enable transposition of attention v operand">,
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1092,15 +1092,19 @@ void PropagateLinalgTransposePass::runOnOperation() {
linalg::populateFoldReshapeOpsByExpansionPatterns(bubblingPatterns,
reshapePropagationFn);
linalg::FillOp::getCanonicalizationPatterns(bubblingPatterns, context);
linalg::ControlFusionFn bubbleTransposeControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
};
IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
bubblingPatterns, bubbleTransposeControlFn);
if (enableAttentionVTranspose) {
linalg::ControlFusionFn bubbleTransposeControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

return IREE::Flow::isNonNullAndOutsideDispatch(
{producer, consumer});
};
IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
bubblingPatterns, bubbleTransposeControlFn);
}
bubblingPatterns.insert<FuseTransposeWithProducerLinalgOp>(
context, enableAggressivePropagation);
bubblingPatterns.insert<BubbleTransposeThroughCollapseShape>(context);
Expand Down

0 comments on commit 618dac5

Please sign in to comment.