Skip to content

Commit 37a315d

Browse files
committed
[RTG] Custom assembly format for 'rtg.test' operation
1 parent de186b6 commit 37a315d

File tree

16 files changed

+183
-81
lines changed

16 files changed

+183
-81
lines changed

frontends/PyRTG/test/basic.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ rtg.sequence @seq0() {
2525
rtg.label local %0
2626
}
2727

28-
rtg.test @test0 : !rtg.dict<> {
28+
rtg.test @test0() {
2929
%0 = rtg.get_sequence @seq0 : !rtg.sequence
3030
%1 = rtg.randomize_sequence %0
3131
rtg.embed_sequence %1

include/circt/Dialect/RTG/IR/RTGOps.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ def TestOp : RTGOp<"test", [
481481
Symbol,
482482
SingleBlock,
483483
NoTerminator,
484+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
484485
HasParent<"mlir::ModuleOp">
485486
]> {
486487
let summary = "the root of a test";
@@ -509,10 +510,7 @@ def TestOp : RTGOp<"test", [
509510
TypeAttrOf<DictType>:$target);
510511
let regions = (region SizedRegion<1>:$bodyRegion);
511512

512-
let assemblyFormat = [{
513-
$sym_name `:` $target attr-dict-with-keyword $bodyRegion
514-
}];
515-
513+
let hasCustomAssemblyFormat = 1;
516514
let hasRegionVerifier = 1;
517515
}
518516

integration_test/Bindings/Python/dialects/rtg.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
# CHECK: [[V1:%.+]] = rtgtest.cpu_decl <1>
3333
# CHECK: rtg.yield [[V0]], [[V1]] : !rtgtest.cpu, !rtgtest.cpu
3434
# CHECK: }
35-
# CHECK: rtg.test @test_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
36-
# CHECK: ^bb{{.*}}(%{{.*}}: !rtgtest.cpu, %{{.*}}: !rtgtest.cpu):
35+
# CHECK: rtg.test @test_name(%cpu0: !rtgtest.cpu, %cpu1: !rtgtest.cpu) {
3736
# CHECK: }
3837
print(m)
3938

@@ -62,7 +61,7 @@
6261
seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name')
6362
rtg.RandomizeSequenceOp(seq_get)
6463

65-
# CHECK: rtg.test @test_name : !rtg.dict<> {
64+
# CHECK: rtg.test @test_name() {
6665
# CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name
6766
# CHECK-NEXT: rtg.randomize_sequence [[SEQ]]
6867
# CHECK-NEXT: }
@@ -76,7 +75,7 @@
7675
rtgtool.populate_randomizer_pipeline(pm, options)
7776
pm.run(m.operation)
7877

79-
# CHECK: rtg.test @test_name : !rtg.dict<> {
78+
# CHECK: rtg.test @test_name() {
8079
# CHECK-NEXT: }
8180
print(m)
8281

lib/Dialect/RTG/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ add_circt_dialect_library(CIRCTRTGDialect
2222
CIRCTRTGISAAssemblyTypeInterfacesIncGen
2323
CIRCTRTGOpInterfacesIncGen
2424
CIRCTRTGTypeInterfacesIncGen
25+
CIRCTSupport
2526
MLIRRTGIncGen
2627

2728
LINK_LIBS PUBLIC

lib/Dialect/RTG/IR/RTGOps.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "circt/Dialect/RTG/IR/RTGOps.h"
14+
#include "circt/Support/ParsingUtils.h"
1415
#include "mlir/IR/Builders.h"
1516
#include "mlir/IR/DialectImplementation.h"
17+
#include "llvm/ADT/SmallString.h"
1618

1719
using namespace mlir;
1820
using namespace circt;
@@ -399,6 +401,115 @@ LogicalResult TestOp::verifyRegions() {
399401
return success();
400402
}
401403

404+
ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
405+
// Parse the name as a symbol.
406+
if (parser.parseSymbolName(
407+
result.getOrAddProperties<TestOp::Properties>().sym_name))
408+
return failure();
409+
410+
// Parse the function signature.
411+
SmallVector<OpAsmParser::Argument> arguments;
412+
SmallVector<StringAttr> names;
413+
414+
auto parseOneArgument = [&]() -> ParseResult {
415+
std::string name;
416+
auto res =
417+
parser.parseOptionalKeywordOrString(&name) || parser.parseColon();
418+
419+
auto argLoc = parser.getCurrentLocation();
420+
if (failed(parser.parseArgument(arguments.emplace_back(),
421+
/*allowType=*/true, /*allowAttrs=*/true)))
422+
return failure();
423+
424+
// If no explicit name was provided, try to use the SSA name.
425+
if (res) {
426+
auto inferredName = parsing_util::getNameFromSSA(
427+
result.getContext(), arguments.back().ssaName.name);
428+
if (inferredName.empty())
429+
return parser.emitError(argLoc, "invalid SSA name for test argument");
430+
names.push_back(inferredName);
431+
} else {
432+
names.push_back(StringAttr::get(result.getContext(), name));
433+
}
434+
435+
return success();
436+
};
437+
if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
438+
parseOneArgument, " in argument list"))
439+
return failure();
440+
441+
SmallVector<Type> argTypes;
442+
SmallVector<DictEntry> entries;
443+
SmallVector<Location> argLocs;
444+
argTypes.reserve(arguments.size());
445+
argLocs.reserve(arguments.size());
446+
for (auto [name, arg] : llvm::zip(names, arguments)) {
447+
argTypes.push_back(arg.type);
448+
argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
449+
entries.push_back({name, arg.type});
450+
}
451+
auto emitError = [&]() -> InFlightDiagnostic {
452+
return parser.emitError(parser.getCurrentLocation());
453+
};
454+
Type type = DictType::getChecked(emitError, result.getContext(),
455+
ArrayRef<DictEntry>(entries));
456+
if (!type)
457+
return failure();
458+
result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);
459+
460+
auto loc = parser.getCurrentLocation();
461+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
462+
return failure();
463+
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
464+
return parser.emitError(loc)
465+
<< "'" << result.name.getStringRef() << "' op ";
466+
})))
467+
return failure();
468+
469+
std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
470+
if (parser.parseRegion(*bodyRegionRegion, arguments))
471+
return failure();
472+
473+
if (bodyRegionRegion->empty()) {
474+
bodyRegionRegion->emplaceBlock();
475+
bodyRegionRegion->addArguments(argTypes, argLocs);
476+
}
477+
result.addRegion(std::move(bodyRegionRegion));
478+
479+
return success();
480+
}
481+
482+
void TestOp::print(OpAsmPrinter &p) {
483+
p << ' ';
484+
p.printSymbolName(getSymNameAttr().getValue());
485+
p << "(";
486+
SmallString<32> resultNameStr;
487+
llvm::interleaveComma(
488+
llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
489+
[&](auto entryAndArg) {
490+
auto [entry, arg] = entryAndArg;
491+
492+
resultNameStr.clear();
493+
llvm::raw_svector_ostream tmpStream(resultNameStr);
494+
p.printOperand(arg, tmpStream);
495+
if (tmpStream.str().drop_front() != entry.name)
496+
p << entry.name.getValue() << ": ";
497+
p.printRegionArgument(arg);
498+
});
499+
p << ")";
500+
p.printOptionalAttrDictWithKeyword(
501+
(*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
502+
p << ' ';
503+
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
504+
}
505+
506+
void TestOp::getAsmBlockArgumentNames(Region &region,
507+
OpAsmSetValueNameFn setNameFn) {
508+
for (auto [entry, arg] :
509+
llvm::zip(getTarget().getEntries(), region.getArguments()))
510+
setNameFn(arg, entry.name.getValue());
511+
}
512+
402513
//===----------------------------------------------------------------------===//
403514
// TargetOp
404515
//===----------------------------------------------------------------------===//

test/CAPI/rtg-pipelines.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ int main(int argc, char **argv) {
2121
ctx, mlirStringRefCreateFromCString(
2222
"rtg.sequence @seq() {\n"
2323
"}\n"
24-
"rtg.test @test : !rtg.dict<> {\n"
24+
"rtg.test @test() {\n"
2525
" %0 = rtg.get_sequence @seq : !rtg.sequence\n"
2626
"}\n"));
2727
if (mlirModuleIsNull(moduleOp)) {

test/Dialect/RTG/IR/basic.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ rtg.target @empty_target : !rtg.dict<> {
8989
rtg.yield
9090
}
9191

92-
// CHECK-LABEL: rtg.test @empty_test : !rtg.dict<> {
93-
rtg.test @empty_test : !rtg.dict<> { }
92+
// CHECK-LABEL: rtg.test @empty_test() {
93+
rtg.test @empty_test() { }
9494

9595
// CHECK-LABEL: rtg.target @target : !rtg.dict<num_cpus: i32, num_modes: i32> {
9696
// CHECK: rtg.yield %{{.*}}, %{{.*}} : i32, i32
@@ -115,19 +115,19 @@ rtg.target @context_switch : !rtg.dict<> {
115115
}
116116

117117
// CHECK-LABEL: @contexts
118-
rtg.test @contexts : !rtg.dict<ctxt0: !rtgtest.cpu> {
119-
^bb0(%arg0: !rtgtest.cpu):
118+
rtg.test @contexts(%ctxt0: !rtgtest.cpu) {
120119
// CHECK: rtg.on_context {{%.+}}, {{%.+}} : !rtgtest.cpu
121120
%seq = rtg.get_sequence @seq0 : !rtg.sequence
122-
rtg.on_context %arg0, %seq : !rtgtest.cpu
121+
rtg.on_context %ctxt0, %seq : !rtgtest.cpu
123122
}
124123

125-
// CHECK-LABEL: rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
126-
// CHECK: ^bb0(%arg0: i32, %arg1: i32):
127-
// CHECK: }
128-
rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
129-
^bb0(%arg0: i32, %arg1: i32):
130-
}
124+
// CHECK-LABEL: rtg.test @test0
125+
// CHECK-SAME: (%num_cpus: i32, %num_modes: i32) {
126+
rtg.test @test0(%num_cpus: i32, %num_modes: i32) { }
127+
128+
// CHECK-LABEL: rtg.test @test1
129+
// CHECK-SAME: (%num_cpus: i32, %num_modes: i32) {
130+
rtg.test @test1(num_cpus: %a: i32, num_modes: %b: i32) { }
131131

132132
// CHECK-LABEL: rtg.sequence @integerHandlingOps
133133
rtg.sequence @integerHandlingOps(%arg0: index, %arg1: index) {

test/Dialect/RTG/IR/errors.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,29 @@ rtg.target @target : !rtg.dict<a: i32> {
8383
// -----
8484

8585
// expected-error @below {{argument types must match dict entry types}}
86-
rtg.test @test : !rtg.dict<a: i32> {
87-
}
86+
"rtg.test"() <{sym_name="test", target=!rtg.dict<a: i32>}> ({^bb0(%b: i8):}) : () -> ()
8887

8988
// -----
9089

9190
// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
92-
rtg.test @test : !rtg.dict<a: i32, a: i32> {
93-
^bb0(%arg0: i32, %arg1: i32):
91+
rtg.test @test(%a: i32, %a: i32) {
9492
}
9593

9694
// -----
9795

9896
// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
99-
rtg.test @test : !rtg.dict<b: i32, a: i32> {
100-
^bb0(%arg0: i32, %arg1: i32):
97+
rtg.test @test(%b: i32, %a: i32) {
10198
}
10299

103100
// -----
104101

102+
// expected-error @below {{invalid SSA name for test argument}}
103+
rtg.test @test0(%0: i32, %1: i32) { }
104+
105+
// -----
106+
105107
// expected-error @below {{empty strings not allowed as entry names}}
106-
rtg.test @test : !rtg.dict<"": i32> {
107-
^bb0(%arg0: i32):
108-
}
108+
rtg.test @test(%dict: !rtg.dict<"": i32>) { }
109109

110110
// -----
111111

0 commit comments

Comments
 (0)