diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 21774597c406..08557c3173a0 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -440,7 +440,9 @@ def StoreOp : CIR_Op<"store", [ // ReturnOp //===----------------------------------------------------------------------===// -def ReturnOp : CIR_Op<"return", [HasParent<"FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">, +def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "IfOp", + "SwitchOp", "DoWhileOp", + "LoopOp"]>, Terminator]> { let summary = "Return from function"; let description = [{ @@ -634,7 +636,7 @@ def ConditionOp : CIR_Op<"condition", [ def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator, ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", "LoopOp", "AwaitOp", - "TernaryOp", "GlobalOp"]>]> { + "TernaryOp", "GlobalOp", "DoWhileOp"]>]> { let summary = "Represents the default branching behaviour of a region"; let description = [{ The `cir.yield` operation terminates regions on different CIR operations, @@ -1163,12 +1165,11 @@ def BrCondOp : CIR_Op<"brcond", def LoopOpKind_For : I32EnumAttrCase<"For", 1, "for">; def LoopOpKind_While : I32EnumAttrCase<"While", 2, "while">; -def LoopOpKind_DoWhile : I32EnumAttrCase<"DoWhile", 3, "dowhile">; def LoopOpKind : I32EnumAttr< "LoopOpKind", "Loop kind", - [LoopOpKind_For, LoopOpKind_While, LoopOpKind_DoWhile]> { + [LoopOpKind_For, LoopOpKind_While]> { let cppNamespace = "::mlir::cir"; } @@ -1259,6 +1260,59 @@ def LoopOp : CIR_Op<"loop", }]; } +//===----------------------------------------------------------------------===// +// DoWhileOp +//===----------------------------------------------------------------------===// + +class WhileOpBase : CIR_Op { + defvar isWhile = !eq(mnemonic, "while"); + let summary = "C/C++ " # !if(isWhile, "while", "do-while") # " loop"; + let builders = [ + OpBuilder<(ins "function_ref":$condBuilder, + "function_ref":$bodyBuilder), [{ + OpBuilder::InsertionGuard guard($_builder); + $_builder.createBlock($_state.addRegion()); + }] # !if(isWhile, [{ + condBuilder($_builder, $_state.location); + $_builder.createBlock($_state.addRegion()); + bodyBuilder($_builder, $_state.location); + }], [{ + bodyBuilder($_builder, $_state.location); + $_builder.createBlock($_state.addRegion()); + condBuilder($_builder, $_state.location); + }])> + ]; +} + +def DoWhileOp : WhileOpBase<"do"> { + let regions = (region MinSizedRegion<1>:$body, SizedRegion<1>:$cond); + let assemblyFormat = " $body `while` $cond attr-dict"; + + let extraClassDeclaration = [{ + Region &getEntry() { return getBody(); } + }]; + + let description = [{ + Represents a C/C++ do-while loop. Identical to `cir.while` but the + condition is evaluated after the body. + + Example: + + ```mlir + cir.do { + cir.break + ^bb2: + cir.yield + } while { + cir.condition %cond : cir.bool + } + ``` + }]; +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// @@ -2604,8 +2658,9 @@ def AllocException : CIR_Op<"alloc_exception", [ // ThrowOp //===----------------------------------------------------------------------===// -def ThrowOp : CIR_Op<"throw", - [HasParent<"FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">, +def ThrowOp : CIR_Op<"throw", [ + ParentOneOf<["FuncOp", "ScopeOp", "IfOp", "SwitchOp", + "DoWhileOp", "LoopOp"]>, Terminator]> { let summary = "(Re)Throws an exception"; let description = [{ diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index 228b163da488..31fcb50d7d46 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -599,6 +599,14 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy { return create(loc); } + /// Create a do-while operation. + mlir::cir::DoWhileOp createDoWhile( + mlir::Location loc, + llvm::function_ref condBuilder, + llvm::function_ref bodyBuilder) { + return create(loc, condBuilder, bodyBuilder); + } + mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst, mlir::Value src, mlir::Value len) { return create(loc, dst, src, len); diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 4e8a36edcfbd..dfa83c28be8f 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -825,7 +825,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) { } mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) { - mlir::cir::LoopOp loopOp; + mlir::cir::DoWhileOp doWhileOp; // TODO: pass in array of attributes. auto doStmtBuilder = [&]() -> mlir::LogicalResult { @@ -837,8 +837,8 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) { // sure we handle all cases. assert(!UnimplementedFeature::requiresCleanups()); - loopOp = builder.create( - getLoc(S.getSourceRange()), mlir::cir::LoopOpKind::DoWhile, + doWhileOp = builder.createDoWhile( + getLoc(S.getSourceRange()), /*condBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { assert(!UnimplementedFeature::createProfileWeightsForLoop()); @@ -854,10 +854,6 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) { if (buildStmt(S.getBody(), /*useCurrentScope=*/true).failed()) loopRes = mlir::failure(); buildStopPoint(&S); - }, - /*stepBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - builder.createYield(loc); }); return loopRes; }; @@ -874,7 +870,7 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) { if (res.failed()) return res; - terminateBody(builder, loopOp.getBody(), getLoc(S.getEndLoc())); + terminateBody(builder, doWhileOp.getBody(), getLoc(S.getEndLoc())); return mlir::success(); } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index bf04b61b5721..f2c21a74695f 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1274,6 +1274,20 @@ llvm::SmallVector LoopOp::getLoopRegions() { return {&getBody()}; } LogicalResult LoopOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// LoopOpInterface Methods +//===----------------------------------------------------------------------===// + +void DoWhileOp::getSuccessorRegions( + ::mlir::RegionBranchPoint point, + ::llvm::SmallVectorImpl<::mlir::RegionSuccessor> ®ions) { + LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions); +} + +::llvm::SmallVector DoWhileOp::getLoopRegions() { + return {&getBody()}; +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp b/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp index ea324bd090b2..e77a6bdf14b8 100644 --- a/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp @@ -47,7 +47,7 @@ struct LifetimeCheckPass : public LifetimeCheckBase { void checkIf(IfOp op); void checkSwitch(SwitchOp op); - void checkLoop(LoopOp op); + void checkLoop(LoopOpInterface op); void checkAlloca(AllocaOp op); void checkStore(StoreOp op); void checkLoad(LoadOp op); @@ -654,7 +654,7 @@ void LifetimeCheckPass::joinPmaps(SmallVectorImpl &pmaps) { } } -void LifetimeCheckPass::checkLoop(LoopOp loopOp) { +void LifetimeCheckPass::checkLoop(LoopOpInterface loopOp) { // 2.4.9. Loops // // A loop is treated as if it were the first two loop iterations unrolled @@ -1850,7 +1850,7 @@ void LifetimeCheckPass::checkOperation(Operation *op) { return checkIf(ifOp); if (auto switchOp = dyn_cast(op)) return checkSwitch(switchOp); - if (auto loopOp = dyn_cast(op)) + if (auto loopOp = dyn_cast(op)) return checkLoop(loopOp); if (auto allocaOp = dyn_cast(op)) return checkAlloca(allocaOp); diff --git a/clang/test/CIR/CodeGen/loop.cpp b/clang/test/CIR/CodeGen/loop.cpp index 0c5d4f0990ab..530ed7606e80 100644 --- a/clang/test/CIR/CodeGen/loop.cpp +++ b/clang/test/CIR/CodeGen/loop.cpp @@ -111,46 +111,40 @@ void l3(bool cond) { // CHECK: cir.func @_Z2l3b // CHECK: cir.scope { -// CHECK-NEXT: cir.loop dowhile(cond : { -// CHECK-NEXT: %[[#TRUE:]] = cir.load %0 : cir.ptr , !cir.bool -// CHECK-NEXT: cir.condition(%[[#TRUE]]) -// CHECK-NEXT: }, step : { -// CHECK-NEXT: cir.yield -// CHECK-NEXT: }) { -// CHECK-NEXT: %3 = cir.load %1 : cir.ptr , !s32i -// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i -// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i -// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr -// CHECK-NEXT: cir.yield +// CHECK-NEXT: cir.do { +// CHECK-NEXT: %3 = cir.load %1 : cir.ptr , !s32i +// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i +// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i +// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr +// CHECK-NEXT: cir.yield +// CHECK-NEXT: } while { +// CHECK-NEXT: %[[#TRUE:]] = cir.load %0 : cir.ptr , !cir.bool +// CHECK-NEXT: cir.condition(%[[#TRUE]]) // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: cir.scope { -// CHECK-NEXT: cir.loop dowhile(cond : { +// CHECK-NEXT: cir.do { +// CHECK-NEXT: %3 = cir.load %1 : cir.ptr , !s32i +// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i +// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i +// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr +// CHECK-NEXT: cir.yield +// CHECK-NEXT: } while { // CHECK-NEXT: %[[#TRUE:]] = cir.const(#true) : !cir.bool // CHECK-NEXT: cir.condition(%[[#TRUE]]) -// CHECK-NEXT: }, step : { -// CHECK-NEXT: cir.yield -// CHECK-NEXT: }) { -// CHECK-NEXT: %3 = cir.load %1 : cir.ptr , !s32i -// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i -// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i -// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr -// CHECK-NEXT: cir.yield // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: cir.scope { -// CHECK-NEXT: cir.loop dowhile(cond : { -// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i -// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool -// CHECK-NEXT: cir.condition(%4) -// CHECK-NEXT: }, step : { -// CHECK-NEXT: cir.yield -// CHECK-NEXT: }) { -// CHECK-NEXT: %3 = cir.load %1 : cir.ptr , !s32i -// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i -// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i -// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr -// CHECK-NEXT: cir.yield +// CHECK-NEXT: cir.do { +// CHECK-NEXT: %3 = cir.load %1 : cir.ptr , !s32i +// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i +// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i +// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr +// CHECK-NEXT: cir.yield +// CHECK-NEXT: } while { +// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i +// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool +// CHECK-NEXT: cir.condition(%4) // CHECK-NEXT: } // CHECK-NEXT: } @@ -191,14 +185,12 @@ void l5() { // CHECK: cir.func @_Z2l5v() // CHECK-NEXT: cir.scope { -// CHECK-NEXT: cir.loop dowhile(cond : { +// CHECK-NEXT: cir.do { +// CHECK-NEXT: cir.yield +// CHECK-NEXT: } while { // CHECK-NEXT: %0 = cir.const(#cir.int<0> : !s32i) : !s32i // CHECK-NEXT: %1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool // CHECK-NEXT: cir.condition(%1) -// CHECK-NEXT: }, step : { -// CHECK-NEXT: cir.yield -// CHECK-NEXT: }) { -// CHECK-NEXT: cir.yield // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: cir.return diff --git a/clang/test/CIR/IR/do-while.cir b/clang/test/CIR/IR/do-while.cir new file mode 100644 index 000000000000..6664b4cfe4bf --- /dev/null +++ b/clang/test/CIR/IR/do-while.cir @@ -0,0 +1,18 @@ +// RUN: cir-opt %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s + +cir.func @testPrintingAndParsing (%arg0 : !cir.bool) -> !cir.void { + cir.do { + cir.yield + } while { + cir.condition(%arg0) + } + cir.return +} + +// CHECK: testPrintingAndParsing +// CHECK: cir.do { +// CHECK: cir.yield +// CHECK: } while { +// CHECK: cir.condition(%arg0) +// CHECK: } diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index a8601342919a..92132f2ba8fb 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -808,3 +808,14 @@ cir.func @const_type_mismatch() -> () { %2 = cir.const(#cir.int<0> : !s8i) : !u8i cir.return } + +// ----- + +cir.func @invalid_cond_region_terminator(%arg0 : !cir.bool) -> !cir.void { + cir.do { // expected-error {{op expected condition region to terminate with 'cir.condition'}} + cir.yield + } while { + cir.yield + } + cir.return +} diff --git a/clang/test/CIR/IR/loop.cir b/clang/test/CIR/IR/loop.cir index 132e68119239..af9271ac9f46 100644 --- a/clang/test/CIR/IR/loop.cir +++ b/clang/test/CIR/IR/loop.cir @@ -58,25 +58,6 @@ cir.func @l0() { } } - cir.scope { - %2 = cir.alloca !u32i, cir.ptr , ["i", init] {alignment = 4 : i64} - %3 = cir.const(#cir.int<0> : !u32i) : !u32i - cir.store %3, %2 : !u32i, cir.ptr - cir.loop dowhile(cond : { - %4 = cir.load %2 : cir.ptr , !u32i - %5 = cir.const(#cir.int<10> : !u32i) : !u32i - %6 = cir.cmp(lt, %4, %5) : !u32i, !cir.bool - cir.condition(%6) - }, step : { - cir.yield - }) { - %4 = cir.load %0 : cir.ptr , !u32i - %5 = cir.const(#cir.int<1> : !u32i) : !u32i - %6 = cir.binop(add, %4, %5) : !u32i - cir.store %6, %0 : !u32i, cir.ptr - cir.yield - } - } cir.return } @@ -123,21 +104,6 @@ cir.func @l0() { // CHECK-NEXT: cir.yield // CHECK-NEXT: } -// CHECK: cir.loop dowhile(cond : { -// CHECK-NEXT: %4 = cir.load %2 : cir.ptr , !u32i -// CHECK-NEXT: %5 = cir.const(#cir.int<10> : !u32i) : !u32i -// CHECK-NEXT: %6 = cir.cmp(lt, %4, %5) : !u32i, !cir.bool -// CHECK-NEXT: cir.condition(%6) -// CHECK-NEXT: }, step : { -// CHECK-NEXT: cir.yield -// CHECK-NEXT: }) { -// CHECK-NEXT: %4 = cir.load %0 : cir.ptr , !u32i -// CHECK-NEXT: %5 = cir.const(#cir.int<1> : !u32i) : !u32i -// CHECK-NEXT: %6 = cir.binop(add, %4, %5) : !u32i -// CHECK-NEXT: cir.store %6, %0 : !u32i, cir.ptr -// CHECK-NEXT: cir.yield -// CHECK-NEXT: } - cir.func @l1(%arg0 : !cir.bool) { cir.scope { cir.loop while(cond : { diff --git a/clang/test/CIR/Lowering/loop.cir b/clang/test/CIR/Lowering/loop.cir index 04c4a5debae0..1ff39e7f84b3 100644 --- a/clang/test/CIR/Lowering/loop.cir +++ b/clang/test/CIR/Lowering/loop.cir @@ -53,12 +53,10 @@ module { // Test do-while cir.loop operation lowering. cir.func @testDoWhile(%arg0 : !cir.bool) { - cir.loop dowhile(cond : { - cir.condition(%arg0) - }, step : { // Droped when lowering while statements. - cir.yield - }) { + cir.do { cir.yield + } while { + cir.condition(%arg0) } cir.return } diff --git a/clang/test/CIR/Lowering/loops-with-break.cir b/clang/test/CIR/Lowering/loops-with-break.cir index 147163ab307f..4452d8b25b32 100644 --- a/clang/test/CIR/Lowering/loops-with-break.cir +++ b/clang/test/CIR/Lowering/loops-with-break.cir @@ -228,15 +228,7 @@ cir.func @testDoWhile() { %1 = cir.const(#cir.int<0> : !s32i) : !s32i cir.store %1, %0 : !s32i, cir.ptr cir.scope { - cir.loop dowhile(cond : { - %2 = cir.load %0 : cir.ptr , !s32i - %3 = cir.const(#cir.int<10> : !s32i) : !s32i - %4 = cir.cmp(lt, %2, %3) : !s32i, !s32i - %5 = cir.cast(int_to_bool, %4 : !s32i), !cir.bool - cir.condition(%5) - }, step : { - cir.yield - }) { + cir.do { %2 = cir.load %0 : cir.ptr , !s32i %3 = cir.unary(inc, %2) : !s32i, !s32i cir.store %3, %0 : !s32i, cir.ptr @@ -250,6 +242,12 @@ cir.func @testDoWhile() { } } cir.yield + } while { + %2 = cir.load %0 : cir.ptr , !s32i + %3 = cir.const(#cir.int<10> : !s32i) : !s32i + %4 = cir.cmp(lt, %2, %3) : !s32i, !s32i + %5 = cir.cast(int_to_bool, %4 : !s32i), !cir.bool + cir.condition(%5) } } cir.return diff --git a/clang/test/CIR/Lowering/loops-with-continue.cir b/clang/test/CIR/Lowering/loops-with-continue.cir index 07cd6179f7ae..0f20d4b01f18 100644 --- a/clang/test/CIR/Lowering/loops-with-continue.cir +++ b/clang/test/CIR/Lowering/loops-with-continue.cir @@ -225,15 +225,7 @@ cir.func @testWhile() { %1 = cir.const(#cir.int<0> : !s32i) : !s32i cir.store %1, %0 : !s32i, cir.ptr cir.scope { - cir.loop dowhile(cond : { - %2 = cir.load %0 : cir.ptr , !s32i - %3 = cir.const(#cir.int<10> : !s32i) : !s32i - %4 = cir.cmp(lt, %2, %3) : !s32i, !s32i - %5 = cir.cast(int_to_bool, %4 : !s32i), !cir.bool - cir.condition(%5) - }, step : { - cir.yield - }) { + cir.do { %2 = cir.load %0 : cir.ptr , !s32i %3 = cir.unary(inc, %2) : !s32i, !s32i cir.store %3, %0 : !s32i, cir.ptr @@ -247,6 +239,12 @@ cir.func @testWhile() { } } cir.yield + } while { + %2 = cir.load %0 : cir.ptr , !s32i + %3 = cir.const(#cir.int<10> : !s32i) : !s32i + %4 = cir.cmp(lt, %2, %3) : !s32i, !s32i + %5 = cir.cast(int_to_bool, %4 : !s32i), !cir.bool + cir.condition(%5) } } cir.return