Skip to content

Commit e82231e

Browse files
sitio-coutolanza
authored andcommitted
[CIR][IR] Refactor for loops
This patch completes the deprecation of the generic `cir.loop` operation by adding a new `cir.for` operation and removing the `cir.loop` op. The new representation removes some bloat and places the regions in order of execution. ghstack-source-id: 886e0da Pull Request resolved: #409
1 parent f6bbb03 commit e82231e

File tree

16 files changed

+204
-283
lines changed

16 files changed

+204
-283
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+70-103
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def StoreOp : CIR_Op<"store", [
442442

443443
def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "IfOp",
444444
"SwitchOp", "DoWhileOp",
445-
"WhileOp", "LoopOp"]>,
445+
"WhileOp", "ForOp"]>,
446446
Terminator]> {
447447
let summary = "Return from function";
448448
let description = [{
@@ -635,7 +635,7 @@ def ConditionOp : CIR_Op<"condition", [
635635
//===----------------------------------------------------------------------===//
636636

637637
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
638-
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", "WhileOp", "LoopOp", "AwaitOp",
638+
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", "WhileOp", "ForOp", "AwaitOp",
639639
"TernaryOp", "GlobalOp", "DoWhileOp"]>]> {
640640
let summary = "Represents the default branching behaviour of a region";
641641
let description = [{
@@ -1159,106 +1159,6 @@ def BrCondOp : CIR_Op<"brcond",
11591159
}];
11601160
}
11611161

1162-
//===----------------------------------------------------------------------===//
1163-
// LoopOp
1164-
//===----------------------------------------------------------------------===//
1165-
1166-
def LoopOpKind_For : I32EnumAttrCase<"For", 1, "for">;
1167-
1168-
def LoopOpKind : I32EnumAttr<
1169-
"LoopOpKind",
1170-
"Loop kind",
1171-
[LoopOpKind_For]> {
1172-
let cppNamespace = "::mlir::cir";
1173-
}
1174-
1175-
def LoopOp : CIR_Op<"loop",
1176-
[LoopOpInterface,
1177-
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
1178-
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
1179-
RecursivelySpeculatable, NoRegionArguments]> {
1180-
let summary = "Loop";
1181-
let description = [{
1182-
`cir.loop` represents C/C++ loop forms. It defines 3 blocks:
1183-
- `cond`: region can contain multiple blocks, terminated by regular
1184-
`cir.yield` when control should yield back to the parent, and
1185-
`cir.continue` when execution continues to the next region.
1186-
The region destination depends on the loop form specified.
1187-
- `step`: region with one block, containing code to compute the
1188-
loop step, must be terminated with `cir.yield`.
1189-
- `body`: region for the loop's body, can contain an arbitrary
1190-
number of blocks.
1191-
1192-
The loop form: `for`, `while` and `dowhile` must also be specified and
1193-
each implies the loop regions execution order.
1194-
1195-
```mlir
1196-
// while (true) {
1197-
// i = i + 1;
1198-
// }
1199-
cir.loop while(cond : {
1200-
cir.continue
1201-
}, step : {
1202-
cir.yield
1203-
}) {
1204-
%3 = cir.load %1 : cir.ptr <i32>, i32
1205-
%4 = cir.const(1 : i32) : i32
1206-
%5 = cir.binop(add, %3, %4) : i32
1207-
cir.store %5, %1 : i32, cir.ptr <i32>
1208-
cir.yield
1209-
}
1210-
```
1211-
}];
1212-
1213-
let arguments = (ins Arg<LoopOpKind, "loop kind">:$kind);
1214-
let regions = (region AnyRegion:$cond, AnyRegion:$body,
1215-
SizedRegion<1>:$step);
1216-
1217-
let assemblyFormat = [{
1218-
$kind
1219-
`(`
1220-
`cond` `:` $cond `,`
1221-
`step` `:` $step
1222-
`)`
1223-
$body
1224-
attr-dict
1225-
}];
1226-
1227-
let skipDefaultBuilders = 1;
1228-
let builders = [
1229-
OpBuilder<(ins
1230-
"cir::LoopOpKind":$kind,
1231-
CArg<"function_ref<void(OpBuilder &, Location)>",
1232-
"nullptr">:$condBuilder,
1233-
CArg<"function_ref<void(OpBuilder &, Location)>",
1234-
"nullptr">:$bodyBuilder,
1235-
CArg<"function_ref<void(OpBuilder &, Location)>",
1236-
"nullptr">:$stepBuilder
1237-
)>
1238-
];
1239-
1240-
let hasVerifier = 1;
1241-
1242-
let extraClassDeclaration = [{
1243-
Region *maybeGetStep() {
1244-
if (getKind() == LoopOpKind::For)
1245-
return &getStep();
1246-
return nullptr;
1247-
}
1248-
1249-
llvm::SmallVector<Region *> getRegionsInExecutionOrder() {
1250-
switch(getKind()) {
1251-
case LoopOpKind::For:
1252-
return llvm::SmallVector<Region *, 3>{&getCond(), &getBody(), &getStep()};
1253-
// case LoopOpKind::While:
1254-
// return llvm::SmallVector<Region *, 2>{&getCond(), &getBody()};
1255-
// case LoopOpKind::DoWhile:
1256-
// return llvm::SmallVector<Region *, 2>{&getBody(), &getCond()};
1257-
}
1258-
}
1259-
}];
1260-
}
1261-
12621162
//===----------------------------------------------------------------------===//
12631163
// While & DoWhileOp
12641164
//===----------------------------------------------------------------------===//
@@ -1337,6 +1237,73 @@ def DoWhileOp : WhileOpBase<"do"> {
13371237
}];
13381238
}
13391239

1240+
//===----------------------------------------------------------------------===//
1241+
// ForOp
1242+
//===----------------------------------------------------------------------===//
1243+
1244+
def ForOp : CIR_Op<"for", [LoopOpInterface, NoRegionArguments]> {
1245+
let summary = "C/C++ for loop counterpart";
1246+
let description = [{
1247+
Represents a C/C++ for loop. It consists of three regions:
1248+
1249+
- `cond`: single block region with the loop's condition. Should be
1250+
terminated with a `cir.condition` operation.
1251+
- `body`: contains the loop body and an arbitrary number of blocks.
1252+
- `step`: single block region with the loop's step.
1253+
1254+
Example:
1255+
1256+
```mlir
1257+
cir.for cond {
1258+
cir.condition(%val)
1259+
} body {
1260+
cir.break
1261+
^bb2:
1262+
cir.yield
1263+
} step {
1264+
cir.yield
1265+
}
1266+
```
1267+
}];
1268+
1269+
let regions = (region SizedRegion<1>:$cond,
1270+
MinSizedRegion<1>:$body,
1271+
SizedRegion<1>:$step);
1272+
let assemblyFormat = [{
1273+
`:` `cond` $cond
1274+
`body` $body
1275+
`step` $step
1276+
attr-dict
1277+
}];
1278+
1279+
let builders = [
1280+
OpBuilder<(ins "function_ref<void(OpBuilder &, Location)>":$condBuilder,
1281+
"function_ref<void(OpBuilder &, Location)>":$bodyBuilder,
1282+
"function_ref<void(OpBuilder &, Location)>":$stepBuilder), [{
1283+
OpBuilder::InsertionGuard guard($_builder);
1284+
1285+
// Build condition region.
1286+
$_builder.createBlock($_state.addRegion());
1287+
condBuilder($_builder, $_state.location);
1288+
1289+
// Build body region.
1290+
$_builder.createBlock($_state.addRegion());
1291+
bodyBuilder($_builder, $_state.location);
1292+
1293+
// Build step region.
1294+
$_builder.createBlock($_state.addRegion());
1295+
stepBuilder($_builder, $_state.location);
1296+
}]>
1297+
];
1298+
1299+
let extraClassDeclaration = [{
1300+
Region *maybeGetStep() { return &getStep(); }
1301+
llvm::SmallVector<Region *> getRegionsInExecutionOrder() {
1302+
return llvm::SmallVector<Region *, 3>{&getCond(), &getBody(), &getStep()};
1303+
}
1304+
}];
1305+
}
1306+
13401307
//===----------------------------------------------------------------------===//
13411308
// GlobalOp
13421309
//===----------------------------------------------------------------------===//
@@ -2684,7 +2651,7 @@ def AllocException : CIR_Op<"alloc_exception", [
26842651

26852652
def ThrowOp : CIR_Op<"throw", [
26862653
ParentOneOf<["FuncOp", "ScopeOp", "IfOp", "SwitchOp",
2687-
"DoWhileOp", "WhileOp", "LoopOp"]>,
2654+
"DoWhileOp", "WhileOp", "ForOp"]>,
26882655
Terminator]> {
26892656
let summary = "(Re)Throws an exception";
26902657
let description = [{

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+9
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,15 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
615615
return create<mlir::cir::WhileOp>(loc, condBuilder, bodyBuilder);
616616
}
617617

618+
/// Create a for operation.
619+
mlir::cir::ForOp createFor(
620+
mlir::Location loc,
621+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
622+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder,
623+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> stepBuilder) {
624+
return create<mlir::cir::ForOp>(loc, condBuilder, bodyBuilder, stepBuilder);
625+
}
626+
618627
mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
619628
mlir::Value src, mlir::Value len) {
620629
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
674674
mlir::LogicalResult
675675
CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
676676
ArrayRef<const Attr *> ForAttrs) {
677-
mlir::cir::LoopOp loopOp;
677+
mlir::cir::ForOp forOp;
678678

679679
// TODO(cir): pass in array of attributes.
680680
auto forStmtBuilder = [&]() -> mlir::LogicalResult {
@@ -697,8 +697,8 @@ CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
697697
// sure we handle all cases.
698698
assert(!UnimplementedFeature::requiresCleanups());
699699

700-
loopOp = builder.create<LoopOp>(
701-
getLoc(S.getSourceRange()), mlir::cir::LoopOpKind::For,
700+
forOp = builder.createFor(
701+
getLoc(S.getSourceRange()),
702702
/*condBuilder=*/
703703
[&](mlir::OpBuilder &b, mlir::Location loc) {
704704
assert(!UnimplementedFeature::createProfileWeightsForLoop());
@@ -743,12 +743,12 @@ CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
743743
if (res.failed())
744744
return res;
745745

746-
terminateBody(builder, loopOp.getBody(), getLoc(S.getEndLoc()));
746+
terminateBody(builder, forOp.getBody(), getLoc(S.getEndLoc()));
747747
return mlir::success();
748748
}
749749

750750
mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
751-
mlir::cir::LoopOp loopOp;
751+
mlir::cir::ForOp forOp;
752752

753753
// TODO: pass in array of attributes.
754754
auto forStmtBuilder = [&]() -> mlir::LogicalResult {
@@ -764,8 +764,8 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
764764
// sure we handle all cases.
765765
assert(!UnimplementedFeature::requiresCleanups());
766766

767-
loopOp = builder.create<LoopOp>(
768-
getLoc(S.getSourceRange()), mlir::cir::LoopOpKind::For,
767+
forOp = builder.createFor(
768+
getLoc(S.getSourceRange()),
769769
/*condBuilder=*/
770770
[&](mlir::OpBuilder &b, mlir::Location loc) {
771771
assert(!UnimplementedFeature::createProfileWeightsForLoop());
@@ -822,7 +822,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
822822
if (res.failed())
823823
return res;
824824

825-
terminateBody(builder, loopOp.getBody(), getLoc(S.getEndLoc()));
825+
terminateBody(builder, forOp.getBody(), getLoc(S.getEndLoc()));
826826
return mlir::success();
827827
}
828828

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+8-36
Original file line numberDiff line numberDiff line change
@@ -1238,42 +1238,6 @@ void CatchOp::build(
12381238
catchBuilder(builder, result.location, result);
12391239
}
12401240

1241-
//===----------------------------------------------------------------------===//
1242-
// LoopOp
1243-
//===----------------------------------------------------------------------===//
1244-
1245-
void LoopOp::build(OpBuilder &builder, OperationState &result,
1246-
cir::LoopOpKind kind,
1247-
function_ref<void(OpBuilder &, Location)> condBuilder,
1248-
function_ref<void(OpBuilder &, Location)> bodyBuilder,
1249-
function_ref<void(OpBuilder &, Location)> stepBuilder) {
1250-
OpBuilder::InsertionGuard guard(builder);
1251-
::mlir::cir::LoopOpKindAttr kindAttr =
1252-
cir::LoopOpKindAttr::get(builder.getContext(), kind);
1253-
result.addAttribute(getKindAttrName(result.name), kindAttr);
1254-
1255-
Region *condRegion = result.addRegion();
1256-
builder.createBlock(condRegion);
1257-
condBuilder(builder, result.location);
1258-
1259-
Region *bodyRegion = result.addRegion();
1260-
builder.createBlock(bodyRegion);
1261-
bodyBuilder(builder, result.location);
1262-
1263-
Region *stepRegion = result.addRegion();
1264-
builder.createBlock(stepRegion);
1265-
stepBuilder(builder, result.location);
1266-
}
1267-
1268-
void LoopOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1269-
SmallVectorImpl<RegionSuccessor> &regions) {
1270-
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
1271-
}
1272-
1273-
llvm::SmallVector<Region *> LoopOp::getLoopRegions() { return {&getBody()}; }
1274-
1275-
LogicalResult LoopOp::verify() { return success(); }
1276-
12771241
//===----------------------------------------------------------------------===//
12781242
// LoopOpInterface Methods
12791243
//===----------------------------------------------------------------------===//
@@ -1296,6 +1260,14 @@ void WhileOp::getSuccessorRegions(
12961260

12971261
::llvm::SmallVector<Region *> WhileOp::getLoopRegions() { return {&getBody()}; }
12981262

1263+
void ForOp::getSuccessorRegions(
1264+
::mlir::RegionBranchPoint point,
1265+
::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &regions) {
1266+
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
1267+
}
1268+
1269+
::llvm::SmallVector<Region *> ForOp::getLoopRegions() { return {&getBody()}; }
1270+
12991271
//===----------------------------------------------------------------------===//
13001272
// GlobalOp
13011273
//===----------------------------------------------------------------------===//

clang/lib/CIR/Interfaces/CIRLoopOpInterface.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ void LoopOpInterface::getLoopOpSuccessorRegions(
3838
// Branching from step: go to condition.
3939
else if (op.maybeGetStep() == point.getRegionOrNull()) {
4040
regions.emplace_back(&op.getCond(), op.getCond().getArguments());
41+
} else {
42+
llvm_unreachable("unexpected branch origin");
4143
}
4244
}
4345

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -2261,7 +2261,6 @@ void ConvertCIRToLLVMPass::runOnOperation() {
22612261
// ,ConstantOp
22622262
// ,FuncOp
22632263
// ,LoadOp
2264-
// ,LoopOp
22652264
// ,ReturnOp
22662265
// ,StoreOp
22672266
// ,YieldOp

clang/test/CIR/CodeGen/loop-scope.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ void l0(void) {
1515
// CPPSCOPE-NEXT: %1 = cir.alloca !s32i, cir.ptr <!s32i>, ["j", init] {alignment = 4 : i64}
1616
// CPPSCOPE-NEXT: %2 = cir.const(#cir.int<0> : !s32i) : !s32i
1717
// CPPSCOPE-NEXT: cir.store %2, %0 : !s32i, cir.ptr <!s32i>
18-
// CPPSCOPE-NEXT: cir.loop for(cond : {
18+
// CPPSCOPE-NEXT: cir.for : cond {
1919

2020
// CSCOPE: cir.func @l0()
2121
// CSCOPE-NEXT: cir.scope {
2222
// CSCOPE-NEXT: %0 = cir.alloca !s32i, cir.ptr <!s32i>, ["i", init] {alignment = 4 : i64}
2323
// CSCOPE-NEXT: %1 = cir.const(#cir.int<0> : !s32i) : !s32i
2424
// CSCOPE-NEXT: cir.store %1, %0 : !s32i, cir.ptr <!s32i>
25-
// CSCOPE-NEXT: cir.loop for(cond : {
25+
// CSCOPE-NEXT: cir.for : cond {
2626

27-
// CSCOPE: }) {
27+
// CSCOPE: } body {
2828
// CSCOPE-NEXT: cir.scope {
2929
// CSCOPE-NEXT: %2 = cir.alloca !s32i, cir.ptr <!s32i>, ["j", init] {alignment = 4 : i64}

0 commit comments

Comments
 (0)