Skip to content

Commit

Permalink
[RTG] Custom assembly format for 'rtg.test' operation
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Feb 4, 2025
1 parent de186b6 commit 37a315d
Show file tree
Hide file tree
Showing 16 changed files with 183 additions and 81 deletions.
2 changes: 1 addition & 1 deletion frontends/PyRTG/test/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ rtg.sequence @seq0() {
rtg.label local %0
}

rtg.test @test0 : !rtg.dict<> {
rtg.test @test0() {
%0 = rtg.get_sequence @seq0 : !rtg.sequence
%1 = rtg.randomize_sequence %0
rtg.embed_sequence %1
Expand Down
6 changes: 2 additions & 4 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def TestOp : RTGOp<"test", [
Symbol,
SingleBlock,
NoTerminator,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
HasParent<"mlir::ModuleOp">
]> {
let summary = "the root of a test";
Expand Down Expand Up @@ -509,10 +510,7 @@ def TestOp : RTGOp<"test", [
TypeAttrOf<DictType>:$target);
let regions = (region SizedRegion<1>:$bodyRegion);

let assemblyFormat = [{
$sym_name `:` $target attr-dict-with-keyword $bodyRegion
}];

let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
}

Expand Down
7 changes: 3 additions & 4 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
# CHECK: [[V1:%.+]] = rtgtest.cpu_decl <1>
# CHECK: rtg.yield [[V0]], [[V1]] : !rtgtest.cpu, !rtgtest.cpu
# CHECK: }
# CHECK: rtg.test @test_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
# CHECK: ^bb{{.*}}(%{{.*}}: !rtgtest.cpu, %{{.*}}: !rtgtest.cpu):
# CHECK: rtg.test @test_name(%cpu0: !rtgtest.cpu, %cpu1: !rtgtest.cpu) {
# CHECK: }
print(m)

Expand Down Expand Up @@ -62,7 +61,7 @@
seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name')
rtg.RandomizeSequenceOp(seq_get)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK: rtg.test @test_name() {
# CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name
# CHECK-NEXT: rtg.randomize_sequence [[SEQ]]
# CHECK-NEXT: }
Expand All @@ -76,7 +75,7 @@
rtgtool.populate_randomizer_pipeline(pm, options)
pm.run(m.operation)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK: rtg.test @test_name() {
# CHECK-NEXT: }
print(m)

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/RTG/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_circt_dialect_library(CIRCTRTGDialect
CIRCTRTGISAAssemblyTypeInterfacesIncGen
CIRCTRTGOpInterfacesIncGen
CIRCTRTGTypeInterfacesIncGen
CIRCTSupport
MLIRRTGIncGen

LINK_LIBS PUBLIC
Expand Down
111 changes: 111 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
//===----------------------------------------------------------------------===//

#include "circt/Dialect/RTG/IR/RTGOps.h"
#include "circt/Support/ParsingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"

using namespace mlir;
using namespace circt;
Expand Down Expand Up @@ -399,6 +401,115 @@ LogicalResult TestOp::verifyRegions() {
return success();
}

ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the name as a symbol.
if (parser.parseSymbolName(
result.getOrAddProperties<TestOp::Properties>().sym_name))
return failure();

// Parse the function signature.
SmallVector<OpAsmParser::Argument> arguments;
SmallVector<StringAttr> names;

auto parseOneArgument = [&]() -> ParseResult {
std::string name;
auto res =
parser.parseOptionalKeywordOrString(&name) || parser.parseColon();

auto argLoc = parser.getCurrentLocation();
if (failed(parser.parseArgument(arguments.emplace_back(),
/*allowType=*/true, /*allowAttrs=*/true)))
return failure();

// If no explicit name was provided, try to use the SSA name.
if (res) {
auto inferredName = parsing_util::getNameFromSSA(
result.getContext(), arguments.back().ssaName.name);
if (inferredName.empty())
return parser.emitError(argLoc, "invalid SSA name for test argument");
names.push_back(inferredName);
} else {
names.push_back(StringAttr::get(result.getContext(), name));
}

return success();
};
if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
parseOneArgument, " in argument list"))
return failure();

SmallVector<Type> argTypes;
SmallVector<DictEntry> entries;
SmallVector<Location> argLocs;
argTypes.reserve(arguments.size());
argLocs.reserve(arguments.size());
for (auto [name, arg] : llvm::zip(names, arguments)) {
argTypes.push_back(arg.type);
argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
entries.push_back({name, arg.type});
}
auto emitError = [&]() -> InFlightDiagnostic {
return parser.emitError(parser.getCurrentLocation());
};
Type type = DictType::getChecked(emitError, result.getContext(),
ArrayRef<DictEntry>(entries));
if (!type)
return failure();
result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);

auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return failure();

std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
if (parser.parseRegion(*bodyRegionRegion, arguments))
return failure();

if (bodyRegionRegion->empty()) {
bodyRegionRegion->emplaceBlock();
bodyRegionRegion->addArguments(argTypes, argLocs);
}
result.addRegion(std::move(bodyRegionRegion));

return success();
}

void TestOp::print(OpAsmPrinter &p) {
p << ' ';
p.printSymbolName(getSymNameAttr().getValue());
p << "(";
SmallString<32> resultNameStr;
llvm::interleaveComma(
llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
[&](auto entryAndArg) {
auto [entry, arg] = entryAndArg;

resultNameStr.clear();
llvm::raw_svector_ostream tmpStream(resultNameStr);
p.printOperand(arg, tmpStream);
if (tmpStream.str().drop_front() != entry.name)
p << entry.name.getValue() << ": ";
p.printRegionArgument(arg);
});
p << ")";
p.printOptionalAttrDictWithKeyword(
(*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
p << ' ';
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}

void TestOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (auto [entry, arg] :
llvm::zip(getTarget().getEntries(), region.getArguments()))
setNameFn(arg, entry.name.getValue());
}

//===----------------------------------------------------------------------===//
// TargetOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion test/CAPI/rtg-pipelines.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int main(int argc, char **argv) {
ctx, mlirStringRefCreateFromCString(
"rtg.sequence @seq() {\n"
"}\n"
"rtg.test @test : !rtg.dict<> {\n"
"rtg.test @test() {\n"
" %0 = rtg.get_sequence @seq : !rtg.sequence\n"
"}\n"));
if (mlirModuleIsNull(moduleOp)) {
Expand Down
22 changes: 11 additions & 11 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ rtg.target @empty_target : !rtg.dict<> {
rtg.yield
}

// CHECK-LABEL: rtg.test @empty_test : !rtg.dict<> {
rtg.test @empty_test : !rtg.dict<> { }
// CHECK-LABEL: rtg.test @empty_test() {
rtg.test @empty_test() { }

// CHECK-LABEL: rtg.target @target : !rtg.dict<num_cpus: i32, num_modes: i32> {
// CHECK: rtg.yield %{{.*}}, %{{.*}} : i32, i32
Expand All @@ -115,19 +115,19 @@ rtg.target @context_switch : !rtg.dict<> {
}

// CHECK-LABEL: @contexts
rtg.test @contexts : !rtg.dict<ctxt0: !rtgtest.cpu> {
^bb0(%arg0: !rtgtest.cpu):
rtg.test @contexts(%ctxt0: !rtgtest.cpu) {
// CHECK: rtg.on_context {{%.+}}, {{%.+}} : !rtgtest.cpu
%seq = rtg.get_sequence @seq0 : !rtg.sequence
rtg.on_context %arg0, %seq : !rtgtest.cpu
rtg.on_context %ctxt0, %seq : !rtgtest.cpu
}

// CHECK-LABEL: rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
// CHECK: ^bb0(%arg0: i32, %arg1: i32):
// CHECK: }
rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
^bb0(%arg0: i32, %arg1: i32):
}
// CHECK-LABEL: rtg.test @test0
// CHECK-SAME: (%num_cpus: i32, %num_modes: i32) {
rtg.test @test0(%num_cpus: i32, %num_modes: i32) { }

// CHECK-LABEL: rtg.test @test1
// CHECK-SAME: (%num_cpus: i32, %num_modes: i32) {
rtg.test @test1(num_cpus: %a: i32, num_modes: %b: i32) { }

// CHECK-LABEL: rtg.sequence @integerHandlingOps
rtg.sequence @integerHandlingOps(%arg0: index, %arg1: index) {
Expand Down
18 changes: 9 additions & 9 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,29 @@ rtg.target @target : !rtg.dict<a: i32> {
// -----

// expected-error @below {{argument types must match dict entry types}}
rtg.test @test : !rtg.dict<a: i32> {
}
"rtg.test"() <{sym_name="test", target=!rtg.dict<a: i32>}> ({^bb0(%b: i8):}) : () -> ()

// -----

// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
rtg.test @test : !rtg.dict<a: i32, a: i32> {
^bb0(%arg0: i32, %arg1: i32):
rtg.test @test(%a: i32, %a: i32) {
}

// -----

// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
rtg.test @test : !rtg.dict<b: i32, a: i32> {
^bb0(%arg0: i32, %arg1: i32):
rtg.test @test(%b: i32, %a: i32) {
}

// -----

// expected-error @below {{invalid SSA name for test argument}}
rtg.test @test0(%0: i32, %1: i32) { }

// -----

// expected-error @below {{empty strings not allowed as entry names}}
rtg.test @test : !rtg.dict<"": i32> {
^bb0(%arg0: i32):
}
rtg.test @test(%dict: !rtg.dict<"": i32>) { }

// -----

Expand Down
Loading

0 comments on commit 37a315d

Please sign in to comment.