From 5d0277164017dec400ca6cc82488d796b9abfd7f Mon Sep 17 00:00:00 2001 From: Vinicius Couto Espindola Date: Mon, 22 Jan 2024 17:21:01 -0300 Subject: [PATCH] [CIR][Lowering][NFC] Refactor LoopOp lowering Leverages the new LoopOpInterface for lowering instead of the LoopOp operation. This is a step towards removing the LoopOp operation. ghstack-source-id: 28c1294833a12669d222a293de76609d2cf19148 Pull Request resolved: https://github.com/llvm/clangir/pull/406 --- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 106 ++++++++---------- 1 file changed, 45 insertions(+), 61 deletions(-) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index f6ddc2553a9f..3317e2654bc6 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -405,10 +405,11 @@ class CIRPtrStrideOpLowering } }; -class CIRLoopOpLowering : public mlir::OpConversionPattern { +class CIRLoopOpInterfaceLowering + : public mlir::OpInterfaceConversionPattern { public: - using mlir::OpConversionPattern::OpConversionPattern; - using LoopKind = mlir::cir::LoopOpKind; + using mlir::OpInterfaceConversionPattern< + mlir::cir::LoopOpInterface>::OpInterfaceConversionPattern; inline void lowerConditionOp(mlir::cir::ConditionOp op, mlir::Block *body, @@ -421,76 +422,59 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern { } mlir::LogicalResult - matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto kind = loopOp.getKind(); - auto *currentBlock = rewriter.getInsertionBlock(); - auto *continueBlock = - rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + matchAndRewrite(mlir::cir::LoopOpInterface op, + mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + // Setup CFG blocks. + auto *entry = rewriter.getInsertionBlock(); + auto *exit = rewriter.splitBlock(entry, rewriter.getInsertionPoint()); + auto *cond = &op.getCond().front(); + auto *body = &op.getBody().front(); + auto *step = (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr); + + // Setup loop entry branch. + rewriter.setInsertionPointToEnd(entry); + rewriter.create(op.getLoc(), &op.getEntry().front()); - // Fetch required info from the condition region. - auto &condRegion = loopOp.getCond(); - auto &condFrontBlock = condRegion.front(); - - // Fetch required info from the body region. - auto &bodyRegion = loopOp.getBody(); - auto &bodyFrontBlock = bodyRegion.front(); - auto bodyYield = - dyn_cast(bodyRegion.back().getTerminator()); + // Branch from condition region to body or exit. + auto conditionOp = cast(cond->getTerminator()); + lowerConditionOp(conditionOp, body, exit, rewriter); - // Fetch required info from the step region. - auto &stepRegion = loopOp.getStep(); - auto &stepFrontBlock = stepRegion.front(); - auto stepYield = - dyn_cast(stepRegion.back().getTerminator()); - auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock); + // TODO(cir): Remove the walks below. It visits operations unnecessarily, + // however, to solve this we would likely need a custom DialecConversion + // driver to customize the order that operations are visited. // Lower continue statements. - mlir::Block &dest = - (kind != LoopKind::For ? condFrontBlock : stepFrontBlock); - walkRegionSkipping( - loopOp.getBody(), [&](mlir::Operation *op) { - if (isa(op)) - lowerTerminator(op, &dest, rewriter); - }); + mlir::Block *dest = (step ? step : cond); + op.walkBodySkippingNestedLoops([&](mlir::Operation *op) { + if (isa(op)) + lowerTerminator(op, dest, rewriter); + }); // Lower break statements. walkRegionSkipping( - loopOp.getBody(), [&](mlir::Operation *op) { + op.getBody(), [&](mlir::Operation *op) { if (isa(op)) - lowerTerminator(op, continueBlock, rewriter); + lowerTerminator(op, exit, rewriter); }); - // Move loop op region contents to current CFG. - rewriter.inlineRegionBefore(condRegion, continueBlock); - rewriter.inlineRegionBefore(bodyRegion, continueBlock); - if (kind == LoopKind::For) // Ignore step if not a for-loop. - rewriter.inlineRegionBefore(stepRegion, continueBlock); + // Lower optional body region yield. + auto bodyYield = dyn_cast(body->getTerminator()); + if (bodyYield) + lowerTerminator(bodyYield, (step ? step : cond), rewriter); - // Set loop entry point to condition or to body in do-while cases. - rewriter.setInsertionPointToEnd(currentBlock); - auto &entry = (kind != LoopKind::DoWhile ? condFrontBlock : bodyFrontBlock); - rewriter.create(loopOp.getLoc(), &entry); - - // Branch from condition region to body or exit. - auto conditionOp = - cast(condFrontBlock.getTerminator()); - lowerConditionOp(conditionOp, &bodyFrontBlock, continueBlock, rewriter); - - // Branch from body to condition or to step on for-loop cases. - if (bodyYield) { - rewriter.setInsertionPoint(bodyYield); - rewriter.replaceOpWithNewOp(bodyYield, &stepBlock); - } + // Lower mandatory step region yield. + if (step) + lowerTerminator(cast(step->getTerminator()), cond, + rewriter); - // Is a for loop: branch from step to condition. - if (kind == LoopKind::For) { - rewriter.setInsertionPoint(stepYield); - rewriter.replaceOpWithNewOp(stepYield, &condFrontBlock); - } + // Move region contents out of the loop op. + rewriter.inlineRegionBefore(op.getCond(), exit); + rewriter.inlineRegionBefore(op.getBody(), exit); + if (step) + rewriter.inlineRegionBefore(*op.maybeGetStep(), exit); - // Remove the loop op. - rewriter.eraseOp(loopOp); + rewriter.eraseOp(op); return mlir::success(); } }; @@ -2094,7 +2078,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); patterns.add< - CIRCmpOpLowering, CIRLoopOpLowering, CIRBrCondOpLowering, + CIRCmpOpLowering, CIRLoopOpInterfaceLowering, CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering, CIRUnaryOpLowering, CIRBinOpLowering, CIRShiftOpLowering, CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering,