Skip to content

Commit 0367528

Browse files
fschlimbsilee2
authored andcommitted
add parallel loops even when in GPU env region
1 parent bc9bd53 commit 0367528

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

lib/Transforms/AddOuterParallelLoop.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,31 @@
1414
//===----------------------------------------------------------------------===//
1515

1616
#include "PassDetail.h"
17+
#include "imex/Dialect/Region/IR/RegionOps.h"
1718
#include "mlir/Dialect/Arith/IR/Arith.h"
1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
1920
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2021
#include "mlir/Dialect/SCF/IR/SCF.h"
2122
#include "llvm/ADT/SetVector.h"
22-
2323
using namespace mlir;
2424
using namespace imex;
2525

2626
namespace {
2727
struct AddOuterParallelLoopPass
2828
: public AddOuterParallelLoopBase<AddOuterParallelLoopPass> {
29-
public:
30-
void runOnOperation() override {
31-
auto func = getOperation();
32-
if (func.getBody().empty())
33-
return;
29+
private:
30+
void runOnBlock(::mlir::Block &block, ::mlir::Operation *parent,
31+
mlir::OpBuilder &builder) {
3432
llvm::SmallVector<llvm::SmallVector<Operation *, 4>, 4> groupedOps;
3533
// populate the top level for-loop
36-
for (auto topIt = func.getBody().front().begin();
37-
topIt != func.getBody().front().end();) {
34+
for (auto topIt = block.begin(); topIt != block.end();) {
35+
auto regOp = dyn_cast<::imex::region::EnvironmentRegionOp>(*topIt);
36+
if (regOp) {
37+
runOnBlock(regOp.getRegion().front(), regOp, builder);
38+
++topIt;
39+
continue;
40+
}
41+
3842
scf::ForOp forOp = dyn_cast<scf::ForOp>(*topIt++);
3943
if (!forOp) {
4044
continue;
@@ -58,7 +62,7 @@ struct AddOuterParallelLoopPass
5862
hasReturnOp = true;
5963
break;
6064
}
61-
while (user->getParentOp() != func) {
65+
while (user->getParentOp() != parent) {
6266
user = user->getParentOp();
6367
}
6468
topUsers.insert(user->getResults().begin(), user->getResults().end());
@@ -80,7 +84,6 @@ struct AddOuterParallelLoopPass
8084
}
8185
}
8286
// move the for-loop and its users into the newly created parallel-loop
83-
mlir::OpBuilder builder(func.getContext());
8487
for (const auto &ops : groupedOps) {
8588
auto op = ops.front();
8689
builder.setInsertionPoint(op);
@@ -97,6 +100,17 @@ struct AddOuterParallelLoopPass
97100
}
98101
}
99102
}
103+
104+
public:
105+
void runOnOperation() override {
106+
auto func = getOperation();
107+
if (func.getBody().empty())
108+
return;
109+
llvm::SmallVector<llvm::SmallVector<Operation *, 4>, 4> groupedOps;
110+
mlir::OpBuilder builder(func.getContext());
111+
// populate the top level for-loop
112+
runOnBlock(func.getBody().front(), func, builder);
113+
}
100114
};
101115
} // namespace
102116

0 commit comments

Comments
 (0)