Skip to content

[RTG] Custom assembly format for 'rtg.test' operation #8188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontends/PyRTG/test/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ rtg.sequence @seq0() {
rtg.label local %0
}

rtg.test @test0 : !rtg.dict<> {
rtg.test @test0() {
%0 = rtg.get_sequence @seq0 : !rtg.sequence
%1 = rtg.randomize_sequence %0
rtg.embed_sequence %1
Expand Down
6 changes: 2 additions & 4 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def TestOp : RTGOp<"test", [
Symbol,
SingleBlock,
NoTerminator,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
HasParent<"mlir::ModuleOp">
]> {
let summary = "the root of a test";
Expand Down Expand Up @@ -509,10 +510,7 @@ def TestOp : RTGOp<"test", [
TypeAttrOf<DictType>:$target);
let regions = (region SizedRegion<1>:$bodyRegion);

let assemblyFormat = [{
$sym_name `:` $target attr-dict-with-keyword $bodyRegion
}];

let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
}

Expand Down
7 changes: 3 additions & 4 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
# CHECK: [[V1:%.+]] = rtgtest.cpu_decl <1>
# CHECK: rtg.yield [[V0]], [[V1]] : !rtgtest.cpu, !rtgtest.cpu
# CHECK: }
# CHECK: rtg.test @test_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
# CHECK: ^bb{{.*}}(%{{.*}}: !rtgtest.cpu, %{{.*}}: !rtgtest.cpu):
# CHECK: rtg.test @test_name(cpu0 = %cpu0: !rtgtest.cpu, cpu1 = %cpu1: !rtgtest.cpu) {
# CHECK: }
print(m)

Expand Down Expand Up @@ -62,7 +61,7 @@
seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name')
rtg.RandomizeSequenceOp(seq_get)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK: rtg.test @test_name() {
# CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name
# CHECK-NEXT: rtg.randomize_sequence [[SEQ]]
# CHECK-NEXT: }
Expand All @@ -76,7 +75,7 @@
rtgtool.populate_randomizer_pipeline(pm, options)
pm.run(m.operation)

# CHECK: rtg.test @test_name : !rtg.dict<> {
# CHECK: rtg.test @test_name() {
# CHECK-NEXT: }
print(m)

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/RTG/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_circt_dialect_library(CIRCTRTGDialect
CIRCTRTGISAAssemblyTypeInterfacesIncGen
CIRCTRTGOpInterfacesIncGen
CIRCTRTGTypeInterfacesIncGen
CIRCTSupport
MLIRRTGIncGen

LINK_LIBS PUBLIC
Expand Down
93 changes: 93 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
//===----------------------------------------------------------------------===//

#include "circt/Dialect/RTG/IR/RTGOps.h"
#include "circt/Support/ParsingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"

using namespace mlir;
using namespace circt;
Expand Down Expand Up @@ -399,6 +401,97 @@ LogicalResult TestOp::verifyRegions() {
return success();
}

ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the name as a symbol.
if (parser.parseSymbolName(
result.getOrAddProperties<TestOp::Properties>().sym_name))
return failure();

// Parse the function signature.
SmallVector<OpAsmParser::Argument> arguments;
SmallVector<StringAttr> names;

auto parseOneArgument = [&]() -> ParseResult {
std::string name;
if (parser.parseKeywordOrString(&name) || parser.parseEqual() ||
parser.parseArgument(arguments.emplace_back(), /*allowType=*/true,
/*allowAttrs=*/true))
return failure();

names.push_back(StringAttr::get(result.getContext(), name));
return success();
};
if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
parseOneArgument, " in argument list"))
return failure();

SmallVector<Type> argTypes;
SmallVector<DictEntry> entries;
SmallVector<Location> argLocs;
argTypes.reserve(arguments.size());
argLocs.reserve(arguments.size());
for (auto [name, arg] : llvm::zip(names, arguments)) {
argTypes.push_back(arg.type);
argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
entries.push_back({name, arg.type});
}
auto emitError = [&]() -> InFlightDiagnostic {
return parser.emitError(parser.getCurrentLocation());
};
Type type = DictType::getChecked(emitError, result.getContext(),
ArrayRef<DictEntry>(entries));
if (!type)
return failure();
result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);

auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return failure();

std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
if (parser.parseRegion(*bodyRegionRegion, arguments))
return failure();

if (bodyRegionRegion->empty()) {
bodyRegionRegion->emplaceBlock();
bodyRegionRegion->addArguments(argTypes, argLocs);
}
result.addRegion(std::move(bodyRegionRegion));

return success();
}

void TestOp::print(OpAsmPrinter &p) {
p << ' ';
p.printSymbolName(getSymNameAttr().getValue());
p << "(";
SmallString<32> resultNameStr;
llvm::interleaveComma(
llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
[&](auto entryAndArg) {
auto [entry, arg] = entryAndArg;
p << entry.name.getValue() << " = ";
p.printRegionArgument(arg);
});
p << ")";
p.printOptionalAttrDictWithKeyword(
(*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
p << ' ';
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}

void TestOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (auto [entry, arg] :
llvm::zip(getTarget().getEntries(), region.getArguments()))
setNameFn(arg, entry.name.getValue());
}

//===----------------------------------------------------------------------===//
// TargetOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion test/CAPI/rtg-pipelines.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int main(int argc, char **argv) {
ctx, mlirStringRefCreateFromCString(
"rtg.sequence @seq() {\n"
"}\n"
"rtg.test @test : !rtg.dict<> {\n"
"rtg.test @test() {\n"
" %0 = rtg.get_sequence @seq : !rtg.sequence\n"
"}\n"));
if (mlirModuleIsNull(moduleOp)) {
Expand Down
22 changes: 11 additions & 11 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ rtg.target @empty_target : !rtg.dict<> {
rtg.yield
}

// CHECK-LABEL: rtg.test @empty_test : !rtg.dict<> {
rtg.test @empty_test : !rtg.dict<> { }
// CHECK-LABEL: rtg.test @empty_test() {
rtg.test @empty_test() { }

// CHECK-LABEL: rtg.target @target : !rtg.dict<num_cpus: i32, num_modes: i32> {
// CHECK: rtg.yield %{{.*}}, %{{.*}} : i32, i32
Expand All @@ -115,19 +115,19 @@ rtg.target @context_switch : !rtg.dict<> {
}

// CHECK-LABEL: @contexts
rtg.test @contexts : !rtg.dict<ctxt0: !rtgtest.cpu> {
^bb0(%arg0: !rtgtest.cpu):
rtg.test @contexts(ctxt0 = %ctxt0: !rtgtest.cpu) {
// CHECK: rtg.on_context {{%.+}}, {{%.+}} : !rtgtest.cpu
%seq = rtg.get_sequence @seq0 : !rtg.sequence
rtg.on_context %arg0, %seq : !rtgtest.cpu
rtg.on_context %ctxt0, %seq : !rtgtest.cpu
}

// CHECK-LABEL: rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
// CHECK: ^bb0(%arg0: i32, %arg1: i32):
// CHECK: }
rtg.test @test : !rtg.dict<num_cpus: i32, num_modes: i32> {
^bb0(%arg0: i32, %arg1: i32):
}
// CHECK-LABEL: rtg.test @test0
// CHECK-SAME: (num_cpus = %num_cpus: i32, num_modes = %num_modes: i32) {
rtg.test @test0(num_cpus = %num_cpus: i32, num_modes = %num_modes: i32) { }

// CHECK-LABEL: rtg.test @test1
// CHECK-SAME: (num_cpus = %num_cpus: i32, num_modes = %num_modes: i32) {
rtg.test @test1(num_cpus = %a: i32, num_modes = %b: i32) { }

// CHECK-LABEL: rtg.sequence @integerHandlingOps
rtg.sequence @integerHandlingOps(%arg0: index, %arg1: index) {
Expand Down
13 changes: 4 additions & 9 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,24 @@ rtg.target @target : !rtg.dict<a: i32> {
// -----

// expected-error @below {{argument types must match dict entry types}}
rtg.test @test : !rtg.dict<a: i32> {
}
"rtg.test"() <{sym_name="test", target=!rtg.dict<a: i32>}> ({^bb0(%b: i8):}) : () -> ()

// -----

// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
rtg.test @test : !rtg.dict<a: i32, a: i32> {
^bb0(%arg0: i32, %arg1: i32):
rtg.test @test(a = %a: i32, a = %a: i32) {
}

// -----

// expected-error @below {{dictionary must be sorted by names and contain no duplicates, first violation at entry 'a'}}
rtg.test @test : !rtg.dict<b: i32, a: i32> {
^bb0(%arg0: i32, %arg1: i32):
rtg.test @test(b = %b: i32, a = %a: i32) {
}

// -----

// expected-error @below {{empty strings not allowed as entry names}}
rtg.test @test : !rtg.dict<"": i32> {
^bb0(%arg0: i32):
}
rtg.test @test(dict = %dict: !rtg.dict<"": i32>) { }

// -----

Expand Down
Loading