@@ -36,77 +36,76 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
36
36
mlir::FunctionOpInterface funcOp = getOperation ();
37
37
IRRewriter rewriter (context);
38
38
39
- // The depth until which we would keep fusing consumer chain.
39
+ // Step 1. Find the first scf.for/forall loop in a reverse post-order walk.
40
+ Operation *scfLoopOp = nullptr ;
41
+ funcOp->walk <WalkOrder::PostOrder, ReverseIterator>(
42
+ [&](LoopLikeOpInterface op) {
43
+ if (isa<scf::ForOp>(op)) {
44
+ scfLoopOp = op;
45
+ return WalkResult::interrupt ();
46
+ } else if (isa<scf::ForallOp>(op)) {
47
+ scfLoopOp = op;
48
+ return WalkResult::interrupt ();
49
+ }
50
+ return WalkResult::advance ();
51
+ });
52
+ if (!scfLoopOp) {
53
+ LLVM_DEBUG (llvm::dbgs ()
54
+ << " There is no scf.for/forall loop to fuse with\n " );
55
+ return ;
56
+ }
57
+ // Step 2. Search for the compute op and its consumer slices.
58
+ linalg::LinalgOp linalgOp;
59
+ scfLoopOp->walk <WalkOrder::PostOrder, ReverseIterator>(
60
+ [&](linalg::LinalgOp op) {
61
+ linalgOp = op;
62
+ return WalkResult::interrupt ();
63
+ });
64
+ if (!linalgOp) {
65
+ LLVM_DEBUG (llvm::dbgs () << " Could not find any compute op\n " );
66
+ return ;
67
+ }
68
+ // Step 3. The depth up to which we keep fusing the consumer chain.
40
69
// TODO(avarma): This should also be part of KernelDispatch logic.
41
70
unsigned fuseDepth = 1 ;
42
71
// Check if there is matmul-elementwise fusion opportunity. If so, overwrite
43
- // the `fuseDepth` to be 2.
44
- funcOp->walk <WalkOrder::PostOrder, ReverseIterator>([&](linalg::LinalgOp op) {
45
- if (isMatmulProducerOfElementwise (op)) {
46
- fuseDepth = 2 ;
47
- return WalkResult::interrupt ();
48
- }
49
- return WalkResult::advance ();
50
- });
72
+ // the `fuseDepth` to be 2
73
+ if (isMatmulProducerOfElementwise (linalgOp)) {
74
+ fuseDepth = 2 ;
75
+ }
51
76
52
- // Based on the `fuseDepth`, we would greedily fuse the consumer ops.
77
+ Operation *computeOp = linalgOp;
78
+ // Step 4. Based on `fuseDepth`, greedily fuse the consumer ops.
53
79
for (unsigned depth = 1 ; depth <= fuseDepth; depth++) {
54
- // Walk through the graph in post order and find the loop.
55
- Operation *scfLoopOp = nullptr ;
56
- funcOp->walk <WalkOrder::PostOrder, ReverseIterator>(
57
- [&](LoopLikeOpInterface op) {
58
- if (isa<scf::ForOp>(op) && useSCFFor) {
59
- scfLoopOp = op;
60
- return WalkResult::interrupt ();
61
- } else if (isa<scf::ForallOp>(op) && !useSCFFor) {
62
- scfLoopOp = op;
63
- return WalkResult::interrupt ();
64
- }
65
- return WalkResult::advance ();
66
- });
67
-
68
- if (!scfLoopOp) {
69
- LLVM_DEBUG (llvm::dbgs ()
70
- << " There is no scf.for/forall loop to fuse with\n " );
71
- return ;
72
- }
73
-
74
- // Search the compute op and its consumer slices.
75
- linalg::LinalgOp linalgOp;
76
- scfLoopOp->walk <WalkOrder::PostOrder, ReverseIterator>(
77
- [&](linalg::LinalgOp op) {
78
- linalgOp = op;
79
- return WalkResult::interrupt ();
80
- });
81
-
82
- if (!linalgOp) {
83
- LLVM_DEBUG (llvm::dbgs () << " Could not find any compute op\n " );
84
- return ;
85
- }
86
-
87
- Value::user_range users = linalgOp->getResult (0 ).getUsers ();
88
- if (!llvm::hasSingleElement (users)) {
89
- linalgOp->emitOpError (" Expected only one user of the compute op" );
90
- return signalPassFailure ();
91
- }
92
-
93
- Operation *terminatorStoreOp = *(users.begin ());
94
- if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
95
- terminatorStoreOp))) {
96
- terminatorStoreOp->emitOpError (
97
- " Expected either tensor.insert_slice OR tensor.parallel_insert_slice "
98
- " to be the only user of the compute op" );
99
- return signalPassFailure ();
100
- }
101
-
102
- std::optional<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
103
- scf::tileAndFuseConsumerOfSlice (rewriter, terminatorStoreOp);
104
- if (!fusedConsumer) {
105
- terminatorStoreOp->emitOpError (
106
- " Failed to fuse any consumer op into the producer" );
107
- return signalPassFailure ();
108
- }
109
- fusedConsumer->origConsumerOperand ->getOwner ()->erase ();
80
+ LLVM_DEBUG (llvm::dbgs () << " Compute op = " << (*computeOp) << " \n " );
81
+ do {
82
+ Value::user_range users = computeOp->getResult (0 ).getUsers ();
83
+ if (!llvm::hasSingleElement (users)) {
84
+ computeOp->emitOpError (" Expected only one user of the compute op" );
85
+ return signalPassFailure ();
86
+ }
87
+
88
+ Operation *terminatorStoreOp = *(users.begin ());
89
+ if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
90
+ terminatorStoreOp))) {
91
+ computeOp = computeOp->getParentOfType <LoopLikeOpInterface>();
92
+ LLVM_DEBUG (llvm::dbgs ()
93
+ << " Going to reattempt fusion because didn't find "
94
+ " tensor.insert_slice/tensor.parallel_insert_slice as the "
95
+ " user of the compute op\n " );
96
+ continue ;
97
+ }
98
+ std::optional<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
99
+ scf::tileAndFuseConsumerOfSlice (rewriter, terminatorStoreOp);
100
+ if (!fusedConsumer) {
101
+ terminatorStoreOp->emitOpError (
102
+ " Failed to fuse any consumer op into the producer" );
103
+ return signalPassFailure ();
104
+ }
105
+ fusedConsumer->origConsumerOperand ->getOwner ()->erase ();
106
+ computeOp = fusedConsumer->tiledAndFusedConsumerOperand ->getOwner ();
107
+ break ;
108
+ } while (computeOp && computeOp->getParentOp () != funcOp);
110
109
}
111
110
}
112
111
0 commit comments