14
14
// ===----------------------------------------------------------------------===//
15
15
16
16
#include " PassDetail.h"
17
+ #include " imex/Dialect/Region/IR/RegionOps.h"
17
18
#include " mlir/Dialect/Arith/IR/Arith.h"
18
19
#include " mlir/Dialect/Func/IR/FuncOps.h"
19
20
#include " mlir/Dialect/GPU/IR/GPUDialect.h"
20
21
#include " mlir/Dialect/SCF/IR/SCF.h"
21
22
#include " llvm/ADT/SetVector.h"
22
-
23
23
using namespace mlir ;
24
24
using namespace imex ;
25
25
26
26
namespace {
27
27
struct AddOuterParallelLoopPass
28
28
: public AddOuterParallelLoopBase<AddOuterParallelLoopPass> {
29
- public:
30
- void runOnOperation () override {
31
- auto func = getOperation ();
32
- if (func.getBody ().empty ())
33
- return ;
29
+ private:
30
+ void runOnBlock (::mlir::Block &block, ::mlir::Operation *parent,
31
+ mlir::OpBuilder &builder) {
34
32
llvm::SmallVector<llvm::SmallVector<Operation *, 4 >, 4 > groupedOps;
35
33
// populate the top level for-loop
36
- for (auto topIt = func.getBody ().front ().begin ();
37
- topIt != func.getBody ().front ().end ();) {
34
+ for (auto topIt = block.begin (); topIt != block.end ();) {
35
+ auto regOp = dyn_cast<::imex::region::EnvironmentRegionOp>(*topIt);
36
+ if (regOp) {
37
+ runOnBlock (regOp.getRegion ().front (), regOp, builder);
38
+ ++topIt;
39
+ continue ;
40
+ }
41
+
38
42
scf::ForOp forOp = dyn_cast<scf::ForOp>(*topIt++);
39
43
if (!forOp) {
40
44
continue ;
@@ -58,7 +62,7 @@ struct AddOuterParallelLoopPass
58
62
hasReturnOp = true ;
59
63
break ;
60
64
}
61
- while (user->getParentOp () != func ) {
65
+ while (user->getParentOp () != parent ) {
62
66
user = user->getParentOp ();
63
67
}
64
68
topUsers.insert (user->getResults ().begin (), user->getResults ().end ());
@@ -80,7 +84,6 @@ struct AddOuterParallelLoopPass
80
84
}
81
85
}
82
86
// move the for-loop and its users into the newly created parallel-loop
83
- mlir::OpBuilder builder (func.getContext ());
84
87
for (const auto &ops : groupedOps) {
85
88
auto op = ops.front ();
86
89
builder.setInsertionPoint (op);
@@ -97,6 +100,17 @@ struct AddOuterParallelLoopPass
97
100
}
98
101
}
99
102
}
103
+
104
+ public:
105
+ void runOnOperation () override {
106
+ auto func = getOperation ();
107
+ if (func.getBody ().empty ())
108
+ return ;
109
+ llvm::SmallVector<llvm::SmallVector<Operation *, 4 >, 4 > groupedOps;
110
+ mlir::OpBuilder builder (func.getContext ());
111
+ // populate the top level for-loop
112
+ runOnBlock (func.getBody ().front (), func, builder);
113
+ }
100
114
};
101
115
} // namespace
102
116
0 commit comments