diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 5c977055e95dc..d3bdc364b0751 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -66,7 +66,8 @@ createCanonicalizerPass(const GreedyRewriteConfig &config, ArrayRef enabledPatterns = std::nullopt); /// Creates a pass to perform control-flow sinking. -std::unique_ptr createControlFlowSinkPass(); +std::unique_ptr createControlFlowSinkPass( + function_ref shouldMoveIntoRegion = nullptr); /// Creates a pass to perform common sub expression elimination. std::unique_ptr createCSEPass(); diff --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp index 4e1dfa1c7c83f..fb9676a2275d6 100644 --- a/mlir/lib/Transforms/ControlFlowSink.cpp +++ b/mlir/lib/Transforms/ControlFlowSink.cpp @@ -30,7 +30,12 @@ using namespace mlir; namespace { /// A control-flow sink pass. struct ControlFlowSink : public impl::ControlFlowSinkBase { + ControlFlowSink( + function_ref shouldMoveIntoRegion) + : shouldMoveIntoRegion(shouldMoveIntoRegion) {} void runOnOperation() override; + + function_ref shouldMoveIntoRegion; }; } // end anonymous namespace @@ -40,19 +45,25 @@ void ControlFlowSink::runOnOperation() { SmallVector regionsToSink; // Get the regions are that known to be executed at most once. getSinglyExecutedRegionsToSink(branch, regionsToSink); - // Sink side-effect free operations. - numSunk = controlFlowSink( - regionsToSink, domInfo, - [](Operation *op, Region *) { return isMemoryEffectFree(op); }, - [](Operation *op, Region *region) { - // Move the operation to the beginning of the region's entry block. - // This guarantees the preservation of SSA dominance of all of the - // operation's uses are in the region. - op->moveBefore(®ion->front(), region->front().begin()); - }); + numSunk = controlFlowSink(regionsToSink, domInfo, shouldMoveIntoRegion, + [](Operation *op, Region *region) { + // Move the operation to the beginning of the + // region's entry block. This guarantees the + // preservation of SSA dominance of all of the + // operation's uses are in the region. + op->moveBefore(®ion->front(), + region->front().begin()); + }); }); } -std::unique_ptr mlir::createControlFlowSinkPass() { - return std::make_unique(); +std::unique_ptr mlir::createControlFlowSinkPass( + function_ref shouldMoveIntoRegion) { + if (!shouldMoveIntoRegion) { + // Sink side-effect free operations. + shouldMoveIntoRegion = [](Operation *op, Region *) { + return isMemoryEffectFree(op); + }; + } + return std::make_unique(shouldMoveIntoRegion); }