Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG] Custom assembly format for 'rtg.test' operation #8188

Open
wants to merge 1 commit into
base: maerhart-pyrtg-boilerplate
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontends/PyRTG/test/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ rtg.sequence @seq0() {
rtg.label local %0
}

rtg.test @test0 : !rtg.dict<> {
rtg.test @test0() {
%0 = rtg.get_sequence @seq0 : !rtg.sequence
%1 = rtg.randomize_sequence %0
rtg.embed_sequence %1
Expand Down
6 changes: 2 additions & 4 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def TestOp : RTGOp<"test", [
Symbol,
SingleBlock,
NoTerminator,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
HasParent<"mlir::ModuleOp">
]> {
let summary = "the root of a test";
Expand Down Expand Up @@ -509,10 +510,7 @@ def TestOp : RTGOp<"test", [
TypeAttrOf<DictType>:$target);
let regions = (region SizedRegion<1>:$bodyRegion);

let assemblyFormat = [{
$sym_name `:` $target attr-dict-with-keyword $bodyRegion
}];

let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
}

Expand Down
7 changes: 3 additions & 4 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
# CHECK: [[V1:%.+]] = rtgtest.cpu_decl <1>
# CHECK: rtg.yield [[V0]], [[V1]] : !rtgtest.cpu, !rtgtest.cpu
# CHECK: }
# CHECK: rtg.test @test_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
# CHECK: ^bb{{.*}}(%{{.*}}: !rtgtest.cpu, %{{.*}}: !rtgtest.cpu):
# CHECK: rtg.test @test_name(%cpu0: !rtgtest.cpu, %cpu1: !rtgtest.cpu) {
# CHECK: }
print(m)

Expand Down Expand Up @@ -62,7 +61,7 @@
seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name')
rtg.RandomizeSequenceOp(seq_get)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK: rtg.test @test_name() {
# CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name
# CHECK-NEXT: rtg.randomize_sequence [[SEQ]]
# CHECK-NEXT: }
Expand All @@ -76,7 +75,7 @@
rtgtool.populate_randomizer_pipeline(pm, options)
pm.run(m.operation)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK: rtg.test @test_name() {
# CHECK-NEXT: }
print(m)

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/RTG/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_circt_dialect_library(CIRCTRTGDialect
CIRCTRTGISAAssemblyTypeInterfacesIncGen
CIRCTRTGOpInterfacesIncGen
CIRCTRTGTypeInterfacesIncGen
CIRCTSupport
MLIRRTGIncGen

LINK_LIBS PUBLIC
Expand Down
111 changes: 111 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
//===----------------------------------------------------------------------===//

#include "circt/Dialect/RTG/IR/RTGOps.h"
#include "circt/Support/ParsingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"

using namespace mlir;
using namespace circt;
Expand Down Expand Up @@ -399,6 +401,115 @@ LogicalResult TestOp::verifyRegions() {
return success();
}

ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the name as a symbol.
if (parser.parseSymbolName(
result.getOrAddProperties<TestOp::Properties>().sym_name))
return failure();

// Parse the function signature.
SmallVector<OpAsmParser::Argument> arguments;
SmallVector<StringAttr> names;

auto parseOneArgument = [&]() -> ParseResult {
std::string name;
auto res =
parser.parseOptionalKeywordOrString(&name) || parser.parseColon();

auto argLoc = parser.getCurrentLocation();
if (failed(parser.parseArgument(arguments.emplace_back(),
/*allowType=*/true, /*allowAttrs=*/true)))
return failure();

// If no explicit name was provided, try to use the SSA name.
if (res) {
auto inferredName = parsing_util::getNameFromSSA(
result.getContext(), arguments.back().ssaName.name);
if (inferredName.empty())
return parser.emitError(argLoc, "invalid SSA name for test argument");
names.push_back(inferredName);
} else {
names.push_back(StringAttr::get(result.getContext(), name));
}

return success();
};
if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
parseOneArgument, " in argument list"))
return failure();

SmallVector<Type> argTypes;
SmallVector<DictEntry> entries;
SmallVector<Location> argLocs;
argTypes.reserve(arguments.size());
argLocs.reserve(arguments.size());
for (auto [name, arg] : llvm::zip(names, arguments)) {
argTypes.push_back(arg.type);
argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
entries.push_back({name, arg.type});
}
auto emitError = [&]() -> InFlightDiagnostic {
return parser.emitError(parser.getCurrentLocation());
};
Type type = DictType::getChecked(emitError, result.getContext(),
ArrayRef<DictEntry>(entries));
if (!type)
return failure();
result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);

auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return failure();

std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
if (parser.parseRegion(*bodyRegionRegion, arguments))
return failure();

if (bodyRegionRegion->empty()) {
bodyRegionRegion->emplaceBlock();
bodyRegionRegion->addArguments(argTypes, argLocs);
}
result.addRegion(std::move(bodyRegionRegion));

return success();
}

void TestOp::print(OpAsmPrinter &p) {
p << ' ';
p.printSymbolName(getSymNameAttr().getValue());
p << "(";
SmallString<32> resultNameStr;
llvm::interleaveComma(
llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
[&](auto entryAndArg) {
auto [entry, arg] = entryAndArg;

resultNameStr.clear();
llvm::raw_svector_ostream tmpStream(resultNameStr);
p.printOperand(arg, tmpStream);
if (tmpStream.str().drop_front() != entry.name)
p << entry.name.getValue() << ": ";
p.printRegionArgument(arg);
});
p << ")";
p.printOptionalAttrDictWithKeyword(
(*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
p << ' ';
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}

void TestOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (auto [entry, arg] :
llvm::zip(getTarget().getEntries(), region.getArguments()))
setNameFn(arg, entry.name.getValue());
}

//===----------------------------------------------------------------------===//
// TargetOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion test/CAPI/rtg-pipelines.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int main(int argc, char **argv) {
ctx, mlirStringRefCreateFromCString(
"rtg.sequence @seq() {\n"
"}\n"
"rtg.test @test : !rtg.dict<> {\n"
"rtg.test @test() {\n"
" %0 = rtg.get_sequence @seq : !rtg.sequence\n"
"}\n"));
if (mlirModuleIsNull(moduleOp)) {
Expand Down
22 changes: 11 additions & 11 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ rtg.target @empty_target : !rtg.dict<> {
rtg.yield
}

// CHECK-LABEL: rtg.test @empty_test : !rtg.dict<> {
rtg.test @empty_test : !rtg.dict<> { }
// CHECK-LABEL: rtg.test @empty_test() {
rtg.test @empty_test() { }

// CHECK-LABEL: rtg.target @target : !rtg.dict<num_cpus: i32, num_modes: i32> {
// CHECK: rtg.yield %{{.*}}, %{{.*}} : i32, i32
Expand All @@ -115,19 +115,19 @@ rtg.target @context_switch : !rtg.dict<> {
}

// CHECK-LABEL: @contexts
rtg.test @contexts : !rtg.dict<ctxt0: !rtgtest.cpu> {
^bb0(%arg0: !rtgtest.cpu):
rtg.test @contexts(%ctxt0: !rtgtest.cpu) {
// CHECK: rtg.on_context {{%.+}}, {{%.+}} : !rtgtest.cpu
%seq = rtg.get_sequence @seq0 : !rtg.sequence
rtg.on_context %arg0, %seq : !rtgtest.cpu
rtg.on_context %ctxt0, %seq : !rtgtest.cpu
}

// CHECK-LABEL: rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
// CHECK: ^bb0(%arg0: i32, %arg1: i32):
// CHECK: }
rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
^bb0(%arg0: i32, %arg1: i32):
}
// CHECK-LABEL: rtg.test @test0
// CHECK-SAME: (%num_cpus: i32, %num_modes: i32) {
rtg.test @test0(%num_cpus: i32, %num_modes: i32) { }

// CHECK-LABEL: rtg.test @test1
// CHECK-SAME: (%num_cpus: i32, %num_modes: i32) {
rtg.test @test1(num_cpus: %a: i32, num_modes: %b: i32) { }

// CHECK-LABEL: rtg.sequence @integerHandlingOps
rtg.sequence @integerHandlingOps(%arg0: index, %arg1: index) {
Expand Down
18 changes: 9 additions & 9 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,29 @@ rtg.target @target : !rtg.dict<a: i32> {
// -----

// expected-error @below {{argument types must match dict entry types}}
rtg.test @test : !rtg.dict<a: i32> {
}
"rtg.test"() <{sym_name="test", target=!rtg.dict<a: i32>}> ({^bb0(%b: i8):}) : () -> ()

// -----

// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
rtg.test @test : !rtg.dict<a: i32, a: i32> {
^bb0(%arg0: i32, %arg1: i32):
rtg.test @test(%a: i32, %a: i32) {
}

// -----

// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
rtg.test @test : !rtg.dict<b: i32, a: i32> {
^bb0(%arg0: i32, %arg1: i32):
rtg.test @test(%b: i32, %a: i32) {
}

// -----

// expected-error @below {{invalid SSA name for test argument}}
rtg.test @test0(%0: i32, %1: i32) { }

// -----

// expected-error @below {{empty strings not allowed as entry names}}
rtg.test @test : !rtg.dict<"": i32> {
^bb0(%arg0: i32):
}
rtg.test @test(%dict: !rtg.dict<"": i32>) { }

// -----

Expand Down
Loading