Skip to content

Commit 8396ed6

Browse files
committed
[CIR][IR] Refactor for loops
This patch completes the deprecation of the generic `cir.loop` operation by adding a new `cir.for` operation and removing the `cir.loop` op. The new representation removes some bloat and places the regions in order of execution. ghstack-source-id: c8184ed Pull Request resolved: #409
1 parent 96355af commit 8396ed6

File tree

16 files changed

+204
-283
lines changed

16 files changed

+204
-283
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+70-103
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def StoreOp : CIR_Op<"store", [
442442

443443
def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "IfOp",
444444
"SwitchOp", "DoWhileOp",
445-
"WhileOp", "LoopOp"]>,
445+
"WhileOp", "ForOp"]>,
446446
Terminator]> {
447447
let summary = "Return from function";
448448
let description = [{
@@ -635,7 +635,7 @@ def ConditionOp : CIR_Op<"condition", [
635635
//===----------------------------------------------------------------------===//
636636

637637
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
638-
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", "WhileOp", "LoopOp", "AwaitOp",
638+
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", "WhileOp", "ForOp", "AwaitOp",
639639
"TernaryOp", "GlobalOp", "DoWhileOp"]>]> {
640640
let summary = "Represents the default branching behaviour of a region";
641641
let description = [{
@@ -1134,106 +1134,6 @@ def BrCondOp : CIR_Op<"brcond",
11341134
}];
11351135
}
11361136

1137-
//===----------------------------------------------------------------------===//
1138-
// LoopOp
1139-
//===----------------------------------------------------------------------===//
1140-
1141-
def LoopOpKind_For : I32EnumAttrCase<"For", 1, "for">;
1142-
1143-
def LoopOpKind : I32EnumAttr<
1144-
"LoopOpKind",
1145-
"Loop kind",
1146-
[LoopOpKind_For]> {
1147-
let cppNamespace = "::mlir::cir";
1148-
}
1149-
1150-
def LoopOp : CIR_Op<"loop",
1151-
[LoopOpInterface,
1152-
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
1153-
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
1154-
RecursivelySpeculatable, NoRegionArguments]> {
1155-
let summary = "Loop";
1156-
let description = [{
1157-
`cir.loop` represents C/C++ loop forms. It defines 3 blocks:
1158-
- `cond`: region can contain multiple blocks, terminated by regular
1159-
`cir.yield` when control should yield back to the parent, and
1160-
`cir.continue` when execution continues to the next region.
1161-
The region destination depends on the loop form specified.
1162-
- `step`: region with one block, containing code to compute the
1163-
loop step, must be terminated with `cir.yield`.
1164-
- `body`: region for the loop's body, can contain an arbitrary
1165-
number of blocks.
1166-
1167-
The loop form: `for`, `while` and `dowhile` must also be specified and
1168-
each implies the loop regions execution order.
1169-
1170-
```mlir
1171-
// while (true) {
1172-
// i = i + 1;
1173-
// }
1174-
cir.loop while(cond : {
1175-
cir.continue
1176-
}, step : {
1177-
cir.yield
1178-
}) {
1179-
%3 = cir.load %1 : cir.ptr <i32>, i32
1180-
%4 = cir.const(1 : i32) : i32
1181-
%5 = cir.binop(add, %3, %4) : i32
1182-
cir.store %5, %1 : i32, cir.ptr <i32>
1183-
cir.yield
1184-
}
1185-
```
1186-
}];
1187-
1188-
let arguments = (ins Arg<LoopOpKind, "loop kind">:$kind);
1189-
let regions = (region AnyRegion:$cond, AnyRegion:$body,
1190-
SizedRegion<1>:$step);
1191-
1192-
let assemblyFormat = [{
1193-
$kind
1194-
`(`
1195-
`cond` `:` $cond `,`
1196-
`step` `:` $step
1197-
`)`
1198-
$body
1199-
attr-dict
1200-
}];
1201-
1202-
let skipDefaultBuilders = 1;
1203-
let builders = [
1204-
OpBuilder<(ins
1205-
"cir::LoopOpKind":$kind,
1206-
CArg<"function_ref<void(OpBuilder &, Location)>",
1207-
"nullptr">:$condBuilder,
1208-
CArg<"function_ref<void(OpBuilder &, Location)>",
1209-
"nullptr">:$bodyBuilder,
1210-
CArg<"function_ref<void(OpBuilder &, Location)>",
1211-
"nullptr">:$stepBuilder
1212-
)>
1213-
];
1214-
1215-
let hasVerifier = 1;
1216-
1217-
let extraClassDeclaration = [{
1218-
Region *maybeGetStep() {
1219-
if (getKind() == LoopOpKind::For)
1220-
return &getStep();
1221-
return nullptr;
1222-
}
1223-
1224-
llvm::SmallVector<Region *> getRegionsInExecutionOrder() {
1225-
switch(getKind()) {
1226-
case LoopOpKind::For:
1227-
return llvm::SmallVector<Region *, 3>{&getCond(), &getBody(), &getStep()};
1228-
case LoopOpKind::While:
1229-
return llvm::SmallVector<Region *, 2>{&getCond(), &getBody()};
1230-
case LoopOpKind::DoWhile:
1231-
return llvm::SmallVector<Region *, 2>{&getBody(), &getCond()};
1232-
}
1233-
}
1234-
}];
1235-
}
1236-
12371137
//===----------------------------------------------------------------------===//
12381138
// While & DoWhileOp
12391139
//===----------------------------------------------------------------------===//
@@ -1312,6 +1212,73 @@ def DoWhileOp : WhileOpBase<"do"> {
13121212
}];
13131213
}
13141214

1215+
//===----------------------------------------------------------------------===//
1216+
// ForOp
1217+
//===----------------------------------------------------------------------===//
1218+
1219+
def ForOp : CIR_Op<"for", [LoopOpInterface, NoRegionArguments]> {
1220+
let summary = "C/C++ for loop counterpart";
1221+
let description = [{
1222+
Represents a C/C++ for loop. It consists of three regions:
1223+
1224+
- `cond`: single block region with the loop's condition. Should be
1225+
terminated with a `cir.condition` operation.
1226+
- `body`: contains the loop body and an arbitrary number of blocks.
1227+
- `step`: single block region with the loop's step.
1228+
1229+
Example:
1230+
1231+
```mlir
1232+
cir.for cond {
1233+
cir.condition(%val)
1234+
} body {
1235+
cir.break
1236+
^bb2:
1237+
cir.yield
1238+
} step {
1239+
cir.yield
1240+
}
1241+
```
1242+
}];
1243+
1244+
let regions = (region SizedRegion<1>:$cond,
1245+
MinSizedRegion<1>:$body,
1246+
SizedRegion<1>:$step);
1247+
let assemblyFormat = [{
1248+
`:` `cond` $cond
1249+
`body` $body
1250+
`step` $step
1251+
attr-dict
1252+
}];
1253+
1254+
let builders = [
1255+
OpBuilder<(ins "function_ref<void(OpBuilder &, Location)>":$condBuilder,
1256+
"function_ref<void(OpBuilder &, Location)>":$bodyBuilder,
1257+
"function_ref<void(OpBuilder &, Location)>":$stepBuilder), [{
1258+
OpBuilder::InsertionGuard guard($_builder);
1259+
1260+
// Build condition region.
1261+
$_builder.createBlock($_state.addRegion());
1262+
condBuilder($_builder, $_state.location);
1263+
1264+
// Build body region.
1265+
$_builder.createBlock($_state.addRegion());
1266+
bodyBuilder($_builder, $_state.location);
1267+
1268+
// Build step region.
1269+
$_builder.createBlock($_state.addRegion());
1270+
stepBuilder($_builder, $_state.location);
1271+
}]>
1272+
];
1273+
1274+
let extraClassDeclaration = [{
1275+
Region *maybeGetStep() { return &getStep(); }
1276+
llvm::SmallVector<Region *> getRegionsInExecutionOrder() {
1277+
return llvm::SmallVector<Region *, 3>{&getCond(), &getBody(), &getStep()};
1278+
}
1279+
}];
1280+
}
1281+
13151282
//===----------------------------------------------------------------------===//
13161283
// GlobalOp
13171284
//===----------------------------------------------------------------------===//
@@ -2659,7 +2626,7 @@ def AllocException : CIR_Op<"alloc_exception", [
26592626

26602627
def ThrowOp : CIR_Op<"throw", [
26612628
ParentOneOf<["FuncOp", "ScopeOp", "IfOp", "SwitchOp",
2662-
"DoWhileOp", "WhileOp", "LoopOp"]>,
2629+
"DoWhileOp", "WhileOp", "ForOp"]>,
26632630
Terminator]> {
26642631
let summary = "(Re)Throws an exception";
26652632
let description = [{

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+9
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,15 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
615615
return create<mlir::cir::WhileOp>(loc, condBuilder, bodyBuilder);
616616
}
617617

618+
/// Create a for operation.
619+
mlir::cir::ForOp createFor(
620+
mlir::Location loc,
621+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
622+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder,
623+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> stepBuilder) {
624+
return create<mlir::cir::ForOp>(loc, condBuilder, bodyBuilder, stepBuilder);
625+
}
626+
618627
mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
619628
mlir::Value src, mlir::Value len) {
620629
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
670670
mlir::LogicalResult
671671
CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
672672
ArrayRef<const Attr *> ForAttrs) {
673-
mlir::cir::LoopOp loopOp;
673+
mlir::cir::ForOp forOp;
674674

675675
// TODO(cir): pass in array of attributes.
676676
auto forStmtBuilder = [&]() -> mlir::LogicalResult {
@@ -693,8 +693,8 @@ CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
693693
// sure we handle all cases.
694694
assert(!UnimplementedFeature::requiresCleanups());
695695

696-
loopOp = builder.create<LoopOp>(
697-
getLoc(S.getSourceRange()), mlir::cir::LoopOpKind::For,
696+
forOp = builder.createFor(
697+
getLoc(S.getSourceRange()),
698698
/*condBuilder=*/
699699
[&](mlir::OpBuilder &b, mlir::Location loc) {
700700
assert(!UnimplementedFeature::createProfileWeightsForLoop());
@@ -739,12 +739,12 @@ CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
739739
if (res.failed())
740740
return res;
741741

742-
terminateBody(builder, loopOp.getBody(), getLoc(S.getEndLoc()));
742+
terminateBody(builder, forOp.getBody(), getLoc(S.getEndLoc()));
743743
return mlir::success();
744744
}
745745

746746
mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
747-
mlir::cir::LoopOp loopOp;
747+
mlir::cir::ForOp forOp;
748748

749749
// TODO: pass in array of attributes.
750750
auto forStmtBuilder = [&]() -> mlir::LogicalResult {
@@ -760,8 +760,8 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
760760
// sure we handle all cases.
761761
assert(!UnimplementedFeature::requiresCleanups());
762762

763-
loopOp = builder.create<LoopOp>(
764-
getLoc(S.getSourceRange()), mlir::cir::LoopOpKind::For,
763+
forOp = builder.createFor(
764+
getLoc(S.getSourceRange()),
765765
/*condBuilder=*/
766766
[&](mlir::OpBuilder &b, mlir::Location loc) {
767767
assert(!UnimplementedFeature::createProfileWeightsForLoop());
@@ -818,7 +818,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
818818
if (res.failed())
819819
return res;
820820

821-
terminateBody(builder, loopOp.getBody(), getLoc(S.getEndLoc()));
821+
terminateBody(builder, forOp.getBody(), getLoc(S.getEndLoc()));
822822
return mlir::success();
823823
}
824824

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+8-36
Original file line numberDiff line numberDiff line change
@@ -1238,42 +1238,6 @@ void CatchOp::build(
12381238
catchBuilder(builder, result.location, result);
12391239
}
12401240

1241-
//===----------------------------------------------------------------------===//
1242-
// LoopOp
1243-
//===----------------------------------------------------------------------===//
1244-
1245-
void LoopOp::build(OpBuilder &builder, OperationState &result,
1246-
cir::LoopOpKind kind,
1247-
function_ref<void(OpBuilder &, Location)> condBuilder,
1248-
function_ref<void(OpBuilder &, Location)> bodyBuilder,
1249-
function_ref<void(OpBuilder &, Location)> stepBuilder) {
1250-
OpBuilder::InsertionGuard guard(builder);
1251-
::mlir::cir::LoopOpKindAttr kindAttr =
1252-
cir::LoopOpKindAttr::get(builder.getContext(), kind);
1253-
result.addAttribute(getKindAttrName(result.name), kindAttr);
1254-
1255-
Region *condRegion = result.addRegion();
1256-
builder.createBlock(condRegion);
1257-
condBuilder(builder, result.location);
1258-
1259-
Region *bodyRegion = result.addRegion();
1260-
builder.createBlock(bodyRegion);
1261-
bodyBuilder(builder, result.location);
1262-
1263-
Region *stepRegion = result.addRegion();
1264-
builder.createBlock(stepRegion);
1265-
stepBuilder(builder, result.location);
1266-
}
1267-
1268-
void LoopOp::getSuccessorRegions(mlir::RegionBranchPoint point,
1269-
SmallVectorImpl<RegionSuccessor> &regions) {
1270-
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
1271-
}
1272-
1273-
llvm::SmallVector<Region *> LoopOp::getLoopRegions() { return {&getBody()}; }
1274-
1275-
LogicalResult LoopOp::verify() { return success(); }
1276-
12771241
//===----------------------------------------------------------------------===//
12781242
// LoopOpInterface Methods
12791243
//===----------------------------------------------------------------------===//
@@ -1296,6 +1260,14 @@ void WhileOp::getSuccessorRegions(
12961260

12971261
::llvm::SmallVector<Region *> WhileOp::getLoopRegions() { return {&getBody()}; }
12981262

1263+
void ForOp::getSuccessorRegions(
1264+
::mlir::RegionBranchPoint point,
1265+
::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &regions) {
1266+
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
1267+
}
1268+
1269+
::llvm::SmallVector<Region *> ForOp::getLoopRegions() { return {&getBody()}; }
1270+
12991271
//===----------------------------------------------------------------------===//
13001272
// GlobalOp
13011273
//===----------------------------------------------------------------------===//

clang/lib/CIR/Interfaces/CIRLoopOpInterface.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ void LoopOpInterface::getLoopOpSuccessorRegions(
3838
// Branching from step: go to condition.
3939
else if (op.maybeGetStep() == point.getRegionOrNull()) {
4040
regions.emplace_back(&op.getCond(), op.getCond().getArguments());
41+
} else {
42+
llvm_unreachable("unexpected branch origin");
4143
}
4244
}
4345

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -2263,7 +2263,6 @@ void ConvertCIRToLLVMPass::runOnOperation() {
22632263
// ,ConstantOp
22642264
// ,FuncOp
22652265
// ,LoadOp
2266-
// ,LoopOp
22672266
// ,ReturnOp
22682267
// ,StoreOp
22692268
// ,YieldOp

clang/test/CIR/CodeGen/loop-scope.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ void l0(void) {
1515
// CPPSCOPE-NEXT: %1 = cir.alloca !s32i, cir.ptr <!s32i>, ["j", init] {alignment = 4 : i64}
1616
// CPPSCOPE-NEXT: %2 = cir.const(#cir.int<0> : !s32i) : !s32i
1717
// CPPSCOPE-NEXT: cir.store %2, %0 : !s32i, cir.ptr <!s32i>
18-
// CPPSCOPE-NEXT: cir.loop for(cond : {
18+
// CPPSCOPE-NEXT: cir.for : cond {
1919

2020
// CSCOPE: cir.func @l0()
2121
// CSCOPE-NEXT: cir.scope {
2222
// CSCOPE-NEXT: %0 = cir.alloca !s32i, cir.ptr <!s32i>, ["i", init] {alignment = 4 : i64}
2323
// CSCOPE-NEXT: %1 = cir.const(#cir.int<0> : !s32i) : !s32i
2424
// CSCOPE-NEXT: cir.store %1, %0 : !s32i, cir.ptr <!s32i>
25-
// CSCOPE-NEXT: cir.loop for(cond : {
25+
// CSCOPE-NEXT: cir.for : cond {
2626

27-
// CSCOPE: }) {
27+
// CSCOPE: } body {
2828
// CSCOPE-NEXT: cir.scope {
2929
// CSCOPE-NEXT: %2 = cir.alloca !s32i, cir.ptr <!s32i>, ["j", init] {alignment = 4 : i64}

0 commit comments

Comments
 (0)