diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 00d771a75c63..1c9b76c7f38e 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -3570,6 +3570,15 @@ def TryOp : CIR_Op<"try", attr-dict }]; + let extraClassDeclaration = [{ + private: + mlir::Region *getCatchLastRegion(); + public: + mlir::Block *getCatchAllEntryBlock(); + mlir::Block *getCatchUnwindEntryBlock(); + bool isCatchAllOnly(); + }]; + // Everything already covered elsewhere. let hasVerifier = 0; let builders = [ diff --git a/clang/lib/CIR/CodeGen/CIRGenException.cpp b/clang/lib/CIR/CodeGen/CIRGenException.cpp index d7cea55dd462..38a94d6f6e19 100644 --- a/clang/lib/CIR/CodeGen/CIRGenException.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenException.cpp @@ -252,16 +252,6 @@ void CIRGenFunction::buildAnyExprToExn(const Expr *e, Address addr) { DeactivateCleanupBlock(cleanup, op); } -static mlir::Block *getResumeBlockFromCatch(mlir::cir::TryOp &tryOp, - mlir::cir::GlobalOp globalParent) { - assert(tryOp && "cir.try expected"); - unsigned numCatchRegions = tryOp.getCatchRegions().size(); - assert(numCatchRegions && "expected at least one region"); - auto &fallbackRegion = tryOp.getCatchRegions()[numCatchRegions - 1]; - return &fallbackRegion.getBlocks().back(); - return nullptr; -} - mlir::Block *CIRGenFunction::getEHResumeBlock(bool isCleanup, mlir::cir::TryOp tryOp) { @@ -270,7 +260,8 @@ mlir::Block *CIRGenFunction::getEHResumeBlock(bool isCleanup, // Just like some other try/catch related logic: return the basic block // pointer but only use it to denote we're tracking things, but there // shouldn't be any changes to that block after work done in this function. - ehResumeBlock = getResumeBlockFromCatch(tryOp, CGM.globalOpContext); + assert(tryOp && "expected available cir.try"); + ehResumeBlock = tryOp.getCatchUnwindEntryBlock(); if (!ehResumeBlock->empty()) return ehResumeBlock; diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index a9c445b08796..282ba04875dc 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1243,6 +1243,27 @@ void TryOp::build( catchBuilder(builder, result.location, result); } +mlir::Region *TryOp::getCatchLastRegion() { + unsigned numCatchRegions = getCatchRegions().size(); + assert(numCatchRegions && "expected at least one region"); + auto &lastRegion = getCatchRegions()[numCatchRegions - 1]; + return &lastRegion; +} + +mlir::Block *TryOp::getCatchUnwindEntryBlock() { + return &getCatchLastRegion()->getBlocks().front(); +} + +mlir::Block *TryOp::getCatchAllEntryBlock() { + return &getCatchLastRegion()->getBlocks().front(); +} + +bool TryOp::isCatchAllOnly() { + mlir::ArrayAttr catchAttrList = getCatchTypesAttr(); + return catchAttrList.size() == 1 && + isa(catchAttrList[0]); +} + void TryOp::getSuccessorRegions(mlir::RegionBranchPoint point, SmallVectorImpl ®ions) { // If any index all the underlying regions branch back to the parent diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 122ee1fe07aa..10d835fc8aa1 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -418,8 +418,6 @@ class CIRTryOpFlattening : public mlir::OpRewritePattern { // Do not update `nextDispatcher`, no more business in try/catch } else if (auto catchUnwind = dyn_cast(catchAttr)) { - // assert(dispatcher->empty() && "expect empty dispatcher"); - // assert(!dispatcher->args_empty() && "expected block argument"); assert(dispatcher->getArguments().size() == 2 && "expected two block argument"); buildUnwindCase(rewriter, catchRegion, dispatcher); @@ -440,15 +438,10 @@ class CIRTryOpFlattening : public mlir::OpRewritePattern { rewriter.setInsertionPointToEnd(beforeCatch); rewriter.replaceOpWithNewOp(tryBodyYield, afterTry); - // Retrieve catch list and some properties. - mlir::ArrayAttr catchAttrList = tryOp.getCatchTypesAttr(); - bool tryOnlyHasCatchAll = catchAttrList.size() == 1 && - isa(catchAttrList[0]); - // Start the landing pad by getting the inflight exception information. mlir::Block *nextDispatcher = buildLandingPads(tryOp, rewriter, beforeCatch, afterTry, callsToRewrite, - landingPads, tryOnlyHasCatchAll); + landingPads, tryOp.isCatchAllOnly()); // Fill in dispatcher to all catch clauses. rewriter.setInsertionPointToEnd(nextDispatcher); @@ -456,6 +449,7 @@ class CIRTryOpFlattening : public mlir::OpRewritePattern { unsigned catchIdx = 0; // Build control-flow for all catch clauses. + mlir::ArrayAttr catchAttrList = tryOp.getCatchTypesAttr(); for (mlir::Attribute catchAttr : catchAttrList) { mlir::Attribute nextCatchAttr; if (catchIdx + 1 < catchAttrList.size())