Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][IR] Refactor do-while loops #407

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,9 @@ def StoreOp : CIR_Op<"store", [
// ReturnOp
//===----------------------------------------------------------------------===//

def ReturnOp : CIR_Op<"return", [HasParent<"FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">,
def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "IfOp",
"SwitchOp", "DoWhileOp",
"LoopOp"]>,
Terminator]> {
let summary = "Return from function";
let description = [{
Expand Down Expand Up @@ -634,7 +636,7 @@ def ConditionOp : CIR_Op<"condition", [

def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", "LoopOp", "AwaitOp",
"TernaryOp", "GlobalOp"]>]> {
"TernaryOp", "GlobalOp", "DoWhileOp"]>]> {
let summary = "Represents the default branching behaviour of a region";
let description = [{
The `cir.yield` operation terminates regions on different CIR operations,
Expand Down Expand Up @@ -1138,12 +1140,11 @@ def BrCondOp : CIR_Op<"brcond",

def LoopOpKind_For : I32EnumAttrCase<"For", 1, "for">;
def LoopOpKind_While : I32EnumAttrCase<"While", 2, "while">;
def LoopOpKind_DoWhile : I32EnumAttrCase<"DoWhile", 3, "dowhile">;

def LoopOpKind : I32EnumAttr<
"LoopOpKind",
"Loop kind",
[LoopOpKind_For, LoopOpKind_While, LoopOpKind_DoWhile]> {
[LoopOpKind_For, LoopOpKind_While]> {
let cppNamespace = "::mlir::cir";
}

Expand Down Expand Up @@ -1229,6 +1230,59 @@ def LoopOp : CIR_Op<"loop",
}];
}

//===----------------------------------------------------------------------===//
// DoWhileOp
//===----------------------------------------------------------------------===//

class WhileOpBase<string mnemonic> : CIR_Op<mnemonic, [
LoopOpInterface,
NoRegionArguments,
]> {
defvar isWhile = !eq(mnemonic, "while");
let summary = "C/C++ " # !if(isWhile, "while", "do-while") # " loop";
let builders = [
OpBuilder<(ins "function_ref<void(OpBuilder &, Location)>":$condBuilder,
"function_ref<void(OpBuilder &, Location)>":$bodyBuilder), [{
OpBuilder::InsertionGuard guard($_builder);
$_builder.createBlock($_state.addRegion());
}] # !if(isWhile, [{
condBuilder($_builder, $_state.location);
$_builder.createBlock($_state.addRegion());
bodyBuilder($_builder, $_state.location);
}], [{
bodyBuilder($_builder, $_state.location);
$_builder.createBlock($_state.addRegion());
condBuilder($_builder, $_state.location);
}])>
];
}

def DoWhileOp : WhileOpBase<"do"> {
let regions = (region MinSizedRegion<1>:$body, SizedRegion<1>:$cond);
let assemblyFormat = " $body `while` $cond attr-dict";

let extraClassDeclaration = [{
Region &getEntry() { return getBody(); }
}];

let description = [{
Represents a C/C++ do-while loop. Identical to `cir.while` but the
condition is evaluated after the body.

Example:

```mlir
cir.do {
cir.break
^bb2:
cir.yield
} while {
cir.condition %cond : cir.bool
}
```
}];
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2574,8 +2628,9 @@ def AllocException : CIR_Op<"alloc_exception", [
// ThrowOp
//===----------------------------------------------------------------------===//

def ThrowOp : CIR_Op<"throw",
[HasParent<"FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">,
def ThrowOp : CIR_Op<"throw", [
ParentOneOf<["FuncOp", "ScopeOp", "IfOp", "SwitchOp",
"DoWhileOp", "LoopOp"]>,
Terminator]> {
let summary = "(Re)Throws an exception";
let description = [{
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,14 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::ContinueOp>(loc);
}

/// Create a do-while operation.
mlir::cir::DoWhileOp createDoWhile(
mlir::Location loc,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
return create<mlir::cir::DoWhileOp>(loc, condBuilder, bodyBuilder);
}

mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
mlir::Value src, mlir::Value len) {
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
Expand Down
12 changes: 4 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
}

mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
mlir::cir::LoopOp loopOp;
mlir::cir::DoWhileOp doWhileOp;

// TODO: pass in array of attributes.
auto doStmtBuilder = [&]() -> mlir::LogicalResult {
Expand All @@ -834,8 +834,8 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
// sure we handle all cases.
assert(!UnimplementedFeature::requiresCleanups());

loopOp = builder.create<LoopOp>(
getLoc(S.getSourceRange()), mlir::cir::LoopOpKind::DoWhile,
doWhileOp = builder.createDoWhile(
getLoc(S.getSourceRange()),
/*condBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
assert(!UnimplementedFeature::createProfileWeightsForLoop());
Expand All @@ -851,10 +851,6 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
if (buildStmt(S.getBody(), /*useCurrentScope=*/true).failed())
loopRes = mlir::failure();
buildStopPoint(&S);
},
/*stepBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
builder.createYield(loc);
});
return loopRes;
};
Expand All @@ -871,7 +867,7 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
if (res.failed())
return res;

terminateBody(builder, loopOp.getBody(), getLoc(S.getEndLoc()));
terminateBody(builder, doWhileOp.getBody(), getLoc(S.getEndLoc()));
return mlir::success();
}

Expand Down
14 changes: 14 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,20 @@ llvm::SmallVector<Region *> LoopOp::getLoopRegions() { return {&getBody()}; }

LogicalResult LoopOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// LoopOpInterface Methods
//===----------------------------------------------------------------------===//

void DoWhileOp::getSuccessorRegions(
::mlir::RegionBranchPoint point,
::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &regions) {
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
}

::llvm::SmallVector<Region *> DoWhileOp::getLoopRegions() {
return {&getBody()};
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 4 additions & 3 deletions clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

#include "clang/CIR/Interfaces/LoopOpInterface.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"

Expand Down Expand Up @@ -46,7 +47,7 @@ struct LifetimeCheckPass : public LifetimeCheckBase<LifetimeCheckPass> {

void checkIf(IfOp op);
void checkSwitch(SwitchOp op);
void checkLoop(LoopOp op);
void checkLoop(LoopOpInterface op);
void checkAlloca(AllocaOp op);
void checkStore(StoreOp op);
void checkLoad(LoadOp op);
Expand Down Expand Up @@ -653,7 +654,7 @@ void LifetimeCheckPass::joinPmaps(SmallVectorImpl<PMapType> &pmaps) {
}
}

void LifetimeCheckPass::checkLoop(LoopOp loopOp) {
void LifetimeCheckPass::checkLoop(LoopOpInterface loopOp) {
// 2.4.9. Loops
//
// A loop is treated as if it were the first two loop iterations unrolled
Expand Down Expand Up @@ -1849,7 +1850,7 @@ void LifetimeCheckPass::checkOperation(Operation *op) {
return checkIf(ifOp);
if (auto switchOp = dyn_cast<SwitchOp>(op))
return checkSwitch(switchOp);
if (auto loopOp = dyn_cast<LoopOp>(op))
if (auto loopOp = dyn_cast<LoopOpInterface>(op))
return checkLoop(loopOp);
if (auto allocaOp = dyn_cast<AllocaOp>(op))
return checkAlloca(allocaOp);
Expand Down
66 changes: 29 additions & 37 deletions clang/test/CIR/CodeGen/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,46 +111,40 @@ void l3(bool cond) {

// CHECK: cir.func @_Z2l3b
// CHECK: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: %[[#TRUE:]] = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
// CHECK-NEXT: cir.condition(%[[#TRUE]])
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %[[#TRUE:]] = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
// CHECK-NEXT: cir.condition(%[[#TRUE]])
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %[[#TRUE:]] = cir.const(#true) : !cir.bool
// CHECK-NEXT: cir.condition(%[[#TRUE]])
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
// CHECK-NEXT: cir.condition(%4)
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: cir.do {
// CHECK-NEXT: %3 = cir.load %1 : cir.ptr <!s32i>, !s32i
// CHECK-NEXT: %4 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %5 = cir.binop(add, %3, %4) : !s32i
// CHECK-NEXT: cir.store %5, %1 : !s32i, cir.ptr <!s32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i
// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
// CHECK-NEXT: cir.condition(%4)
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down Expand Up @@ -191,14 +185,12 @@ void l5() {

// CHECK: cir.func @_Z2l5v()
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: cir.loop dowhile(cond : {
// CHECK-NEXT: cir.do {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: } while {
// CHECK-NEXT: %0 = cir.const(#cir.int<0> : !s32i) : !s32i
// CHECK-NEXT: %1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool
// CHECK-NEXT: cir.condition(%1)
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: cir.return
Expand Down
18 changes: 18 additions & 0 deletions clang/test/CIR/IR/do-while.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: cir-opt %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

cir.func @testPrintingAndParsing (%arg0 : !cir.bool) -> !cir.void {
cir.do {
cir.yield
} while {
cir.condition(%arg0)
}
cir.return
}

// CHECK: testPrintingAndParsing
// CHECK: cir.do {
// CHECK: cir.yield
// CHECK: } while {
// CHECK: cir.condition(%arg0)
// CHECK: }
11 changes: 11 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,14 @@ cir.func @const_type_mismatch() -> () {
%2 = cir.const(#cir.int<0> : !s8i) : !u8i
cir.return
}

// -----

cir.func @invalid_cond_region_terminator(%arg0 : !cir.bool) -> !cir.void {
cir.do { // expected-error {{op expected condition region to terminate with 'cir.condition'}}
cir.yield
} while {
cir.yield
}
cir.return
}
34 changes: 0 additions & 34 deletions clang/test/CIR/IR/loop.cir
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,6 @@ cir.func @l0() {
}
}

cir.scope {
%2 = cir.alloca !u32i, cir.ptr <!u32i>, ["i", init] {alignment = 4 : i64}
%3 = cir.const(#cir.int<0> : !u32i) : !u32i
cir.store %3, %2 : !u32i, cir.ptr <!u32i>
cir.loop dowhile(cond : {
%4 = cir.load %2 : cir.ptr <!u32i>, !u32i
%5 = cir.const(#cir.int<10> : !u32i) : !u32i
%6 = cir.cmp(lt, %4, %5) : !u32i, !cir.bool
cir.condition(%6)
}, step : {
cir.yield
}) {
%4 = cir.load %0 : cir.ptr <!u32i>, !u32i
%5 = cir.const(#cir.int<1> : !u32i) : !u32i
%6 = cir.binop(add, %4, %5) : !u32i
cir.store %6, %0 : !u32i, cir.ptr <!u32i>
cir.yield
}
}
cir.return
}

Expand Down Expand Up @@ -123,21 +104,6 @@ cir.func @l0() {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }

// CHECK: cir.loop dowhile(cond : {
// CHECK-NEXT: %4 = cir.load %2 : cir.ptr <!u32i>, !u32i
// CHECK-NEXT: %5 = cir.const(#cir.int<10> : !u32i) : !u32i
// CHECK-NEXT: %6 = cir.cmp(lt, %4, %5) : !u32i, !cir.bool
// CHECK-NEXT: cir.condition(%6)
// CHECK-NEXT: }, step : {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }) {
// CHECK-NEXT: %4 = cir.load %0 : cir.ptr <!u32i>, !u32i
// CHECK-NEXT: %5 = cir.const(#cir.int<1> : !u32i) : !u32i
// CHECK-NEXT: %6 = cir.binop(add, %4, %5) : !u32i
// CHECK-NEXT: cir.store %6, %0 : !u32i, cir.ptr <!u32i>
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }

cir.func @l1(%arg0 : !cir.bool) {
cir.scope {
cir.loop while(cond : {
Expand Down
Loading