Skip to content

Commit

Permalink
[CIR][IR] Refactor do-while loops
Browse files Browse the repository at this point in the history
Creates a separate C/C++ operation for do-while loops, while keeping the
LoopOpInterface to generically handle loops. This simplifies the IR
generation and printing/parsing of do-while loops. It also allows us to
define it regions in the order that they are executed, which is useful
for the lifetime analysis.

ghstack-source-id: b4d9517197b8f82ae677dc2684101fe5762b21b7
Pull Request resolved: #407
  • Loading branch information
sitio-couto authored and lanza committed Apr 29, 2024
1 parent 5d02771 commit 5836fa1
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 113 deletions.
71 changes: 63 additions & 8 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,9 @@ def StoreOp : CIR_Op<"store", [
// ReturnOp
//===----------------------------------------------------------------------===//

def ReturnOp : CIR_Op<"return", [HasParent<"FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">,
def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "IfOp",
"SwitchOp", "DoWhileOp",
"LoopOp"]>,
Terminator]> {
let summary = "Return from function";
let description = [{
Expand Down Expand Up @@ -634,7 +636,7 @@ def ConditionOp : CIR_Op<"condition", [

def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", "LoopOp", "AwaitOp",
"TernaryOp", "GlobalOp"]>]> {
"TernaryOp", "GlobalOp", "DoWhileOp"]>]> {
let summary = "Represents the default branching behaviour of a region";
let description = [{
The `cir.yield` operation terminates regions on different CIR operations,
Expand Down Expand Up @@ -1163,12 +1165,11 @@ def BrCondOp : CIR_Op<"brcond",

def LoopOpKind_For : I32EnumAttrCase<"For", 1, "for">;
def LoopOpKind_While : I32EnumAttrCase<"While", 2, "while">;
def LoopOpKind_DoWhile : I32EnumAttrCase<"DoWhile", 3, "dowhile">;

def LoopOpKind : I32EnumAttr<
"LoopOpKind",
"Loop kind",
[LoopOpKind_For, LoopOpKind_While, LoopOpKind_DoWhile]> {
[LoopOpKind_For, LoopOpKind_While]> {
let cppNamespace = "::mlir::cir";
}

Expand Down Expand Up @@ -1252,13 +1253,66 @@ def LoopOp : CIR_Op<"loop",
return llvm::SmallVector<Region *, 3>{&getCond(), &getBody(), &getStep()};
case LoopOpKind::While:
return llvm::SmallVector<Region *, 2>{&getCond(), &getBody()};
case LoopOpKind::DoWhile:
return llvm::SmallVector<Region *, 2>{&getBody(), &getCond()};
// case LoopOpKind::DoWhile:
// return llvm::SmallVector<Region *, 2>{&getBody(), &getCond()};
}
}
}];
}

//===----------------------------------------------------------------------===//
// DoWhileOp
//===----------------------------------------------------------------------===//

class WhileOpBase<string mnemonic> : CIR_Op<mnemonic, [
LoopOpInterface,
NoRegionArguments,
]> {
defvar isWhile = !eq(mnemonic, "while");
let summary = "C/C++ " # !if(isWhile, "while", "do-while") # " loop";
let builders = [
OpBuilder<(ins "function_ref<void(OpBuilder &, Location)>":$condBuilder,
"function_ref<void(OpBuilder &, Location)>":$bodyBuilder), [{
OpBuilder::InsertionGuard guard($_builder);
$_builder.createBlock($_state.addRegion());
}] # !if(isWhile, [{
condBuilder($_builder, $_state.location);
$_builder.createBlock($_state.addRegion());
bodyBuilder($_builder, $_state.location);
}], [{
bodyBuilder($_builder, $_state.location);
$_builder.createBlock($_state.addRegion());
condBuilder($_builder, $_state.location);
}])>
];
}

def DoWhileOp : WhileOpBase<"do"> {
let regions = (region MinSizedRegion<1>:$body, SizedRegion<1>:$cond);
let assemblyFormat = " $body `while` $cond attr-dict";

let extraClassDeclaration = [{
Region &getEntry() { return getBody(); }
}];

let description = [{
Represents a C/C++ do-while loop. Identical to `cir.while` but the
condition is evaluated after the body.

Example:

```mlir
cir.do {
cir.break
^bb2:
cir.yield
} while {
cir.condition %cond : cir.bool
}
```
}];
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2604,8 +2658,9 @@ def AllocException : CIR_Op<"alloc_exception", [
// ThrowOp
//===----------------------------------------------------------------------===//

def ThrowOp : CIR_Op<"throw",
[HasParent<"FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">,
def ThrowOp : CIR_Op<"throw", [
ParentOneOf<["FuncOp", "ScopeOp", "IfOp", "SwitchOp",
"DoWhileOp", "LoopOp"]>,
Terminator]> {
let summary = "(Re)Throws an exception";
let description = [{
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,14 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::ContinueOp>(loc);
}

/// Create a do-while operation.
mlir::cir::DoWhileOp createDoWhile(
mlir::Location loc,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
return create<mlir::cir::DoWhileOp>(loc, condBuilder, bodyBuilder);
}

mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
mlir::Value src, mlir::Value len) {
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
Expand Down
12 changes: 4 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
}

mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
mlir::cir::LoopOp loopOp;
mlir::cir::DoWhileOp doWhileOp;

// TODO: pass in array of attributes.
auto doStmtBuilder = [&]() -> mlir::LogicalResult {
Expand All @@ -839,8 +839,8 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
// sure we handle all cases.
assert(!UnimplementedFeature::requiresCleanups());

loopOp = builder.create<LoopOp>(
getLoc(S.getSourceRange()), mlir::cir::LoopOpKind::DoWhile,
doWhileOp = builder.createDoWhile(
getLoc(S.getSourceRange()),
/*condBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
assert(!UnimplementedFeature::createProfileWeightsForLoop());
Expand All @@ -856,10 +856,6 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
if (buildStmt(S.getBody(), /*useCurrentScope=*/true).failed())
loopRes = mlir::failure();
buildStopPoint(&S);
},
/*stepBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
builder.createYield(loc);
});
return loopRes;
};
Expand All @@ -876,7 +872,7 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
if (res.failed())
return res;

terminateBody(builder, loopOp.getBody(), getLoc(S.getEndLoc()));
terminateBody(builder, doWhileOp.getBody(), getLoc(S.getEndLoc()));
return mlir::success();
}

Expand Down
14 changes: 14 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,20 @@ llvm::SmallVector<Region *> LoopOp::getLoopRegions() { return {&getBody()}; }

LogicalResult LoopOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// LoopOpInterface Methods
//===----------------------------------------------------------------------===//

void DoWhileOp::getSuccessorRegions(
::mlir::RegionBranchPoint point,
::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &regions) {
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
}

::llvm::SmallVector<Region *> DoWhileOp::getLoopRegions() {
return {&getBody()};
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct LifetimeCheckPass : public LifetimeCheckBase<LifetimeCheckPass> {

void checkIf(IfOp op);
void checkSwitch(SwitchOp op);
void checkLoop(LoopOp op);
void checkLoop(LoopOpInterface op);
void checkAlloca(AllocaOp op);
void checkStore(StoreOp op);
void checkLoad(LoadOp op);
Expand Down Expand Up @@ -654,7 +654,7 @@ void LifetimeCheckPass::joinPmaps(SmallVectorImpl<PMapType> &pmaps) {
}
}

void LifetimeCheckPass::checkLoop(LoopOp loopOp) {
void LifetimeCheckPass::checkLoop(LoopOpInterface loopOp) {
// 2.4.9. Loops
//
// A loop is treated as if it were the first two loop iterations unrolled
Expand Down Expand Up @@ -1850,7 +1850,7 @@ void LifetimeCheckPass::checkOperation(Operation *op) {
return checkIf(ifOp);
if (auto switchOp = dyn_cast<SwitchOp>(op))
return checkSwitch(switchOp);
if (auto loopOp = dyn_cast<LoopOp>(op))
if (auto loopOp = dyn_cast<LoopOpInterface>(op))
return checkLoop(loopOp);
if (auto allocaOp = dyn_cast<AllocaOp>(op))
return checkAlloca(allocaOp);
Expand Down
66 changes: 29 additions & 37 deletions clang/test/CIR/CodeGen/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,46 +111,40 @@ void l3(bool cond) {

// CHECK: cir.func @_Z2l3b
// CHECK: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: %[[#TRUE:]] = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
// CHECK-NEXT: cir.condition(%[[#TRUE]])
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %[[#TRUE:]] = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
// CHECK-NEXT: cir.condition(%[[#TRUE]])
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %[[#TRUE:]] = cir.const(#true) : !cir.bool
// CHECK-NEXT: cir.condition(%[[#TRUE]])
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
// CHECK-NEXT: cir.condition(%4)
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
// CHECK-NEXT: cir.condition(%4)
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down Expand Up @@ -191,14 +185,12 @@ void l5() {

// CHECK: cir.func @_Z2l5v()
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: cir.do {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %0 = cir.const(#cir.int<0> : !s32i) : !s32i
// CHECK-NEXT: %1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool
// CHECK-NEXT: cir.condition(%1)
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: cir.return
Expand Down
18 changes: 18 additions & 0 deletions clang/test/CIR/IR/do-while.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: cir-opt %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

cir.func @testPrintingAndParsing (%arg0 : !cir.bool) -> !cir.void {
cir.do {
cir.yield
} while {
cir.condition(%arg0)
}
cir.return
}

// CHECK: testPrintingAndParsing
// CHECK: cir.do {
// CHECK: cir.yield
// CHECK: } while {
// CHECK: cir.condition(%arg0)
// CHECK: }
11 changes: 11 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,14 @@ cir.func @const_type_mismatch() -> () {
%2 = cir.const(#cir.int<0> : !s8i) : !u8i
cir.return
}

// -----

cir.func @invalid_cond_region_terminator(%arg0 : !cir.bool) -> !cir.void {
cir.do { // expected-error {{op expected condition region to terminate with 'cir.condition'}}
cir.yield
} while {
cir.yield
}
cir.return
}
34 changes: 0 additions & 34 deletions clang/test/CIR/IR/loop.cir
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,6 @@ cir.func @l0() {
}
}

cir.scope {
%2 = cir.alloca !u32i, cir.ptr <!u32i>, ["i", init] {alignment = 4 : i64}
%3 = cir.const(#cir.int<0> : !u32i) : !u32i
cir.store %3, %2 : !u32i, cir.ptr <!u32i>
cir.loop dowhile(cond : {
%4 = cir.load %2 : cir.ptr <!u32i>, !u32i
%5 = cir.const(#cir.int<10> : !u32i) : !u32i
%6 = cir.cmp(lt, %4, %5) : !u32i, !cir.bool
cir.condition(%6)
}, step : {
cir.yield
}) {
%4 = cir.load %0 : cir.ptr <!u32i>, !u32i
%5 = cir.const(#cir.int<1> : !u32i) : !u32i
%6 = cir.binop(add, %4, %5) : !u32i
cir.store %6, %0 : !u32i, cir.ptr <!u32i>
cir.yield
}
}
cir.return
}

Expand Down Expand Up @@ -123,21 +104,6 @@ cir.func @l0() {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }

// CHECK: cir.loop dowhile(cond : {
// CHECK-NEXT: %4 = cir.load %2 : cir.ptr <!u32i>, !u32i
// CHECK-NEXT: %5 = cir.const(#cir.int<10> : !u32i) : !u32i
// CHECK-NEXT: %6 = cir.cmp(lt, %4, %5) : !u32i, !cir.bool
// CHECK-NEXT: cir.condition(%6)
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %4 = cir.load %0 : cir.ptr <!u32i>, !u32i
// CHECK-NEXT: %5 = cir.const(#cir.int<1> : !u32i) : !u32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) : !u32i
// CHECK-NEXT: cir.store %6, %0 : !u32i, cir.ptr <!u32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }

cir.func @l1(%arg0 : !cir.bool) {
cir.scope {
cir.loop while(cond : {
Expand Down
Loading

0 comments on commit 5836fa1

Please sign in to comment.