Skip to content

Commit ed82f00

Browse files
[FuseConsumer] Update consumer fusion pass to be more generic
-- This commit makes the consumer fusion pass more generic in how it picks the compute op/loop to fuse with. -- This helps in [increasing the L2 tile size](https://github.com/nod-ai/iree-amd-aie/tree/increase_L2_tile_sizes). Signed-off-by: Abhishek Varma <[email protected]>
1 parent af67ee2 commit ed82f00

File tree

3 files changed

+65
-91
lines changed

3 files changed

+65
-91
lines changed

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp

+65-66
Original file line numberDiff line numberDiff line change
@@ -36,77 +36,76 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
3636
mlir::FunctionOpInterface funcOp = getOperation();
3737
IRRewriter rewriter(context);
3838

39-
// The depth until which we would keep fusing consumer chain.
39+
// Step 1. Find the first scf loop in postorder walk.
40+
Operation *scfLoopOp = nullptr;
41+
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>(
42+
[&](LoopLikeOpInterface op) {
43+
if (isa<scf::ForOp>(op)) {
44+
scfLoopOp = op;
45+
return WalkResult::interrupt();
46+
} else if (isa<scf::ForallOp>(op)) {
47+
scfLoopOp = op;
48+
return WalkResult::interrupt();
49+
}
50+
return WalkResult::advance();
51+
});
52+
if (!scfLoopOp) {
53+
LLVM_DEBUG(llvm::dbgs()
54+
<< "There is no scf.for/forall loop to fuse with\n");
55+
return;
56+
}
57+
// Step 2. Search the compute op and its consumer slices.
58+
linalg::LinalgOp linalgOp;
59+
scfLoopOp->walk<WalkOrder::PostOrder, ReverseIterator>(
60+
[&](linalg::LinalgOp op) {
61+
linalgOp = op;
62+
return WalkResult::interrupt();
63+
});
64+
if (!linalgOp) {
65+
LLVM_DEBUG(llvm::dbgs() << "Could not find any compute op\n");
66+
return;
67+
}
68+
// Step 3. The depth until which we would keep fusing consumer chain.
4069
// TODO(avarma): This should also be part of KernelDispatch logic.
4170
unsigned fuseDepth = 1;
4271
// Check if there is matmul-elementwise fusion opportunity. If so, overwrite
43-
// the `fuseDepth` to be 2.
44-
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>([&](linalg::LinalgOp op) {
45-
if (isMatmulProducerOfElementwise(op)) {
46-
fuseDepth = 2;
47-
return WalkResult::interrupt();
48-
}
49-
return WalkResult::advance();
50-
});
72+
// the `fuseDepth` to be 2
73+
if (isMatmulProducerOfElementwise(linalgOp)) {
74+
fuseDepth = 2;
75+
}
5176

52-
// Based on the `fuseDepth`, we would greedily fuse the consumer ops.
77+
Operation *computeOp = linalgOp;
78+
// Step 4. Based on the `fuseDepth`, we would greedily fuse the consumer ops.
5379
for (unsigned depth = 1; depth <= fuseDepth; depth++) {
54-
// Walk through the graph in post order and find the loop.
55-
Operation *scfLoopOp = nullptr;
56-
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>(
57-
[&](LoopLikeOpInterface op) {
58-
if (isa<scf::ForOp>(op) && useSCFFor) {
59-
scfLoopOp = op;
60-
return WalkResult::interrupt();
61-
} else if (isa<scf::ForallOp>(op) && !useSCFFor) {
62-
scfLoopOp = op;
63-
return WalkResult::interrupt();
64-
}
65-
return WalkResult::advance();
66-
});
67-
68-
if (!scfLoopOp) {
69-
LLVM_DEBUG(llvm::dbgs()
70-
<< "There is no scf.for/forall loop to fuse with\n");
71-
return;
72-
}
73-
74-
// Search the compute op and its consumer slices.
75-
linalg::LinalgOp linalgOp;
76-
scfLoopOp->walk<WalkOrder::PostOrder, ReverseIterator>(
77-
[&](linalg::LinalgOp op) {
78-
linalgOp = op;
79-
return WalkResult::interrupt();
80-
});
81-
82-
if (!linalgOp) {
83-
LLVM_DEBUG(llvm::dbgs() << "Could not find any compute op\n");
84-
return;
85-
}
86-
87-
Value::user_range users = linalgOp->getResult(0).getUsers();
88-
if (!llvm::hasSingleElement(users)) {
89-
linalgOp->emitOpError("Expected only one user of the compute op");
90-
return signalPassFailure();
91-
}
92-
93-
Operation *terminatorStoreOp = *(users.begin());
94-
if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
95-
terminatorStoreOp))) {
96-
terminatorStoreOp->emitOpError(
97-
"Expected either tensor.insert_slice OR tensor.parallel_insert_slice "
98-
"to be the only user of the compute op");
99-
return signalPassFailure();
100-
}
101-
102-
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
103-
scf::tileAndFuseConsumerOfSlice(rewriter, terminatorStoreOp);
104-
if (!fusedConsumer) {
105-
terminatorStoreOp->emitOpError(
106-
"Failed to fuse any consumer op into the producer");
107-
return signalPassFailure();
108-
}
109-
fusedConsumer->origConsumerOperand->getOwner()->erase();
80+
LLVM_DEBUG(llvm::dbgs() << "Compute op = " << (*computeOp) << "\n");
81+
do {
82+
Value::user_range users = computeOp->getResult(0).getUsers();
83+
if (!llvm::hasSingleElement(users)) {
84+
computeOp->emitOpError("Expected only one user of the compute op");
85+
return signalPassFailure();
86+
}
87+
88+
Operation *terminatorStoreOp = *(users.begin());
89+
if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
90+
terminatorStoreOp))) {
91+
computeOp = computeOp->getParentOfType<LoopLikeOpInterface>();
92+
LLVM_DEBUG(llvm::dbgs()
93+
<< "Going to reattempt fusion because didn't find "
94+
"tensor.insert_slice/tensor.parallel_insert_slice as the "
95+
"user of the compute op\n");
96+
continue;
97+
}
98+
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
99+
scf::tileAndFuseConsumerOfSlice(rewriter, terminatorStoreOp);
100+
if (!fusedConsumer) {
101+
terminatorStoreOp->emitOpError(
102+
"Failed to fuse any consumer op into the producer");
103+
return signalPassFailure();
104+
}
105+
fusedConsumer->origConsumerOperand->getOwner()->erase();
106+
computeOp = fusedConsumer->tiledAndFusedConsumerOperand->getOwner();
107+
break;
108+
} while (computeOp && computeOp->getParentOp() != funcOp);
110109
}
111110
}
112111

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_consumer_into_loop_scf_for.mlir

-14
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,6 @@ func.func @no_user_of_producer(%arg0: tensor<64xf32>) {
195195

196196
// -----
197197

198-
func.func @insert_slice_not_found(%arg0: tensor<64xf32>) -> tensor<64xf32> {
199-
%c0 = arith.constant 0 : index
200-
%c1 = arith.constant 1 : index
201-
%c8 = arith.constant 8 : index
202-
%0 = scf.for %arg1 = %c0 to %c8 step %c1 iter_args(%out = %arg0) -> tensor<64xf32> {
203-
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%arg0, %arg0 : tensor<64xf32>, tensor<64xf32>) outs(%arg0 : tensor<64xf32>) -> tensor<64xf32>
204-
// expected-error @+1 {{Expected either tensor.insert_slice OR tensor.parallel_insert_slice to be the only user of the compute op}}
205-
scf.yield %1 : tensor<64xf32>
206-
}
207-
return %0 : tensor<64xf32>
208-
}
209-
210-
// -----
211-
212198
func.func @no_consumer_fusion(%arg0: tensor<64xf32>) -> tensor<64xf32> {
213199
%c0 = arith.constant 0 : index
214200
%c1 = arith.constant 1 : index

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_consumer_into_loop_scf_forall.mlir

-11
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,6 @@ func.func @no_user_of_producer(%arg0: tensor<64xf32>) {
188188

189189
// -----
190190

191-
func.func @parallel_insert_slice_not_found(%arg0: tensor<64xf32>) -> tensor<64xf32> {
192-
%0 = scf.forall (%arg1, %arg2) in (1,2) shared_outs(%out = %arg0) -> tensor<64xf32> {
193-
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%arg0, %arg0 : tensor<64xf32>, tensor<64xf32>) outs(%arg0 : tensor<64xf32>) -> tensor<64xf32>
194-
// expected-error @+1 {{Expected either tensor.insert_slice OR tensor.parallel_insert_slice to be the only user of the compute op}}
195-
%2 = arith.mulf %1, %out : tensor<64xf32>
196-
}
197-
return %0 : tensor<64xf32>
198-
}
199-
200-
// -----
201-
202191
func.func @no_consumer_fusion(%arg0: tensor<64xf32>) -> tensor<64xf32> {
203192
%0 = scf.forall (%arg1, %arg2) in (1,2) shared_outs(%out = %arg0) -> tensor<64xf32> {
204193
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%arg0, %arg0 : tensor<64xf32>, tensor<64xf32>) outs(%arg0 : tensor<64xf32>) -> tensor<64xf32>

0 commit comments

Comments
 (0)