Skip to content

Commit

Permalink
[CIR][Lowering][NFC] Refactor LoopOp lowering
Browse files Browse the repository at this point in the history
Leverages the new LoopOpInterface for lowering instead of the LoopOp
operation. This is a step towards removing the LoopOp operation.

ghstack-source-id: 28c1294833a12669d222a293de76609d2cf19148
Pull Request resolved: #406
  • Loading branch information
sitio-couto authored and lanza committed Apr 29, 2024
1 parent e9fdf3a commit f5d53c3
Showing 1 changed file with 45 additions and 61 deletions.
106 changes: 45 additions & 61 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,11 @@ class CIRPtrStrideOpLowering
}
};

class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
class CIRLoopOpInterfaceLowering
: public mlir::OpInterfaceConversionPattern<mlir::cir::LoopOpInterface> {
public:
using mlir::OpConversionPattern<mlir::cir::LoopOp>::OpConversionPattern;
using LoopKind = mlir::cir::LoopOpKind;
using mlir::OpInterfaceConversionPattern<
mlir::cir::LoopOpInterface>::OpInterfaceConversionPattern;

inline void
lowerConditionOp(mlir::cir::ConditionOp op, mlir::Block *body,
Expand All @@ -421,76 +422,59 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto kind = loopOp.getKind();
auto *currentBlock = rewriter.getInsertionBlock();
auto *continueBlock =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
matchAndRewrite(mlir::cir::LoopOpInterface op,
mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
// Setup CFG blocks.
auto *entry = rewriter.getInsertionBlock();
auto *exit = rewriter.splitBlock(entry, rewriter.getInsertionPoint());
auto *cond = &op.getCond().front();
auto *body = &op.getBody().front();
auto *step = (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

// Setup loop entry branch.
rewriter.setInsertionPointToEnd(entry);
rewriter.create<mlir::LLVM::BrOp>(op.getLoc(), &op.getEntry().front());

// Fetch required info from the condition region.
auto &condRegion = loopOp.getCond();
auto &condFrontBlock = condRegion.front();

// Fetch required info from the body region.
auto &bodyRegion = loopOp.getBody();
auto &bodyFrontBlock = bodyRegion.front();
auto bodyYield =
dyn_cast<mlir::cir::YieldOp>(bodyRegion.back().getTerminator());
// Branch from condition region to body or exit.
auto conditionOp = cast<mlir::cir::ConditionOp>(cond->getTerminator());
lowerConditionOp(conditionOp, body, exit, rewriter);

// Fetch required info from the step region.
auto &stepRegion = loopOp.getStep();
auto &stepFrontBlock = stepRegion.front();
auto stepYield =
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
// TODO(cir): Remove the walks below. It visits operations unnecessarily,
// however, to solve this we would likely need a custom DialecConversion
// driver to customize the order that operations are visited.

// Lower continue statements.
mlir::Block &dest =
(kind != LoopKind::For ? condFrontBlock : stepFrontBlock);
walkRegionSkipping<mlir::cir::LoopOpInterface>(
loopOp.getBody(), [&](mlir::Operation *op) {
if (isa<mlir::cir::ContinueOp>(op))
lowerTerminator(op, &dest, rewriter);
});
mlir::Block *dest = (step ? step : cond);
op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
if (isa<mlir::cir::ContinueOp>(op))
lowerTerminator(op, dest, rewriter);
});

// Lower break statements.
walkRegionSkipping<mlir::cir::LoopOpInterface, mlir::cir::SwitchOp>(
loopOp.getBody(), [&](mlir::Operation *op) {
op.getBody(), [&](mlir::Operation *op) {
if (isa<mlir::cir::BreakOp>(op))
lowerTerminator(op, continueBlock, rewriter);
lowerTerminator(op, exit, rewriter);
});

// Move loop op region contents to current CFG.
rewriter.inlineRegionBefore(condRegion, continueBlock);
rewriter.inlineRegionBefore(bodyRegion, continueBlock);
if (kind == LoopKind::For) // Ignore step if not a for-loop.
rewriter.inlineRegionBefore(stepRegion, continueBlock);
// Lower optional body region yield.
auto bodyYield = dyn_cast<mlir::cir::YieldOp>(body->getTerminator());
if (bodyYield)
lowerTerminator(bodyYield, (step ? step : cond), rewriter);

// Set loop entry point to condition or to body in do-while cases.
rewriter.setInsertionPointToEnd(currentBlock);
auto &entry = (kind != LoopKind::DoWhile ? condFrontBlock : bodyFrontBlock);
rewriter.create<mlir::cir::BrOp>(loopOp.getLoc(), &entry);

// Branch from condition region to body or exit.
auto conditionOp =
cast<mlir::cir::ConditionOp>(condFrontBlock.getTerminator());
lowerConditionOp(conditionOp, &bodyFrontBlock, continueBlock, rewriter);

// Branch from body to condition or to step on for-loop cases.
if (bodyYield) {
rewriter.setInsertionPoint(bodyYield);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &stepBlock);
}
// Lower mandatory step region yield.
if (step)
lowerTerminator(cast<mlir::cir::YieldOp>(step->getTerminator()), cond,
rewriter);

// Is a for loop: branch from step to condition.
if (kind == LoopKind::For) {
rewriter.setInsertionPoint(stepYield);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(stepYield, &condFrontBlock);
}
// Move region contents out of the loop op.
rewriter.inlineRegionBefore(op.getCond(), exit);
rewriter.inlineRegionBefore(op.getBody(), exit);
if (step)
rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

// Remove the loop op.
rewriter.eraseOp(loopOp);
rewriter.eraseOp(op);
return mlir::success();
}
};
Expand Down Expand Up @@ -2094,7 +2078,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering>(patterns.getContext());
patterns.add<
CIRCmpOpLowering, CIRLoopOpLowering, CIRBrCondOpLowering,
CIRCmpOpLowering, CIRLoopOpInterfaceLowering, CIRBrCondOpLowering,
CIRPtrStrideOpLowering, CIRCallLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRShiftOpLowering, CIRLoadLowering,
CIRConstantLowering, CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering,
Expand Down

0 comments on commit f5d53c3

Please sign in to comment.