Skip to content

Commit 540d7dd

Browse files
authored
[mlir][py] Plumb OpPrintingFlags::printNameLocAsPrefix() through the C/Python APIs (#129607)
1 parent d38380d commit 540d7dd

File tree

5 files changed

+26
-7
lines changed

5 files changed

+26
-7
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,10 @@ mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable,
456456
MLIR_CAPI_EXPORTED void
457457
mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags);
458458

459+
/// Print the name and location, if NamedLoc, as a prefix to the SSA ID.
460+
MLIR_CAPI_EXPORTED void
461+
mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags);
462+
459463
/// Use local scope when printing the operation. This allows for using the
460464
/// printer in a more localized and thread-safe setting, but may not
461465
/// necessarily be identical to what the IR will look like when dumping

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,8 +1291,9 @@ void PyOperation::checkValid() const {
12911291
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
12921292
bool enableDebugInfo, bool prettyDebugInfo,
12931293
bool printGenericOpForm, bool useLocalScope,
1294-
bool assumeVerified, nb::object fileObject,
1295-
bool binary, bool skipRegions) {
1294+
bool useNameLocAsPrefix, bool assumeVerified,
1295+
nb::object fileObject, bool binary,
1296+
bool skipRegions) {
12961297
PyOperation &operation = getOperation();
12971298
operation.checkValid();
12981299
if (fileObject.is_none())
@@ -1314,6 +1315,8 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
13141315
mlirOpPrintingFlagsAssumeVerified(flags);
13151316
if (skipRegions)
13161317
mlirOpPrintingFlagsSkipRegions(flags);
1318+
if (useNameLocAsPrefix)
1319+
mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
13171320

13181321
PyFileAccumulator accum(fileObject, binary);
13191322
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
@@ -1390,7 +1393,8 @@ nb::object PyOperationBase::getAsm(bool binary,
13901393
std::optional<int64_t> largeElementsLimit,
13911394
bool enableDebugInfo, bool prettyDebugInfo,
13921395
bool printGenericOpForm, bool useLocalScope,
1393-
bool assumeVerified, bool skipRegions) {
1396+
bool useNameLocAsPrefix, bool assumeVerified,
1397+
bool skipRegions) {
13941398
nb::object fileObject;
13951399
if (binary) {
13961400
fileObject = nb::module_::import_("io").attr("BytesIO")();
@@ -1402,6 +1406,7 @@ nb::object PyOperationBase::getAsm(bool binary,
14021406
/*prettyDebugInfo=*/prettyDebugInfo,
14031407
/*printGenericOpForm=*/printGenericOpForm,
14041408
/*useLocalScope=*/useLocalScope,
1409+
/*useNameLocAsPrefix=*/useNameLocAsPrefix,
14051410
/*assumeVerified=*/assumeVerified,
14061411
/*fileObject=*/fileObject,
14071412
/*binary=*/binary,
@@ -3195,6 +3200,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
31953200
/*prettyDebugInfo=*/false,
31963201
/*printGenericOpForm=*/false,
31973202
/*useLocalScope=*/false,
3203+
/*useNameLocAsPrefix=*/false,
31983204
/*assumeVerified=*/false,
31993205
/*skipRegions=*/false);
32003206
},
@@ -3206,14 +3212,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32063212
nb::arg("binary") = false, kOperationPrintStateDocstring)
32073213
.def("print",
32083214
nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3209-
bool, nb::object, bool, bool>(
3215+
bool, bool, nb::object, bool, bool>(
32103216
&PyOperationBase::print),
32113217
// Careful: Lots of arguments must match up with print method.
32123218
nb::arg("large_elements_limit").none() = nb::none(),
32133219
nb::arg("enable_debug_info") = false,
32143220
nb::arg("pretty_debug_info") = false,
32153221
nb::arg("print_generic_op_form") = false,
32163222
nb::arg("use_local_scope") = false,
3223+
nb::arg("use_name_loc_as_prefix") = false,
32173224
nb::arg("assume_verified") = false,
32183225
nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
32193226
nb::arg("skip_regions") = false, kOperationPrintDocstring)
@@ -3228,6 +3235,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32283235
nb::arg("pretty_debug_info") = false,
32293236
nb::arg("print_generic_op_form") = false,
32303237
nb::arg("use_local_scope") = false,
3238+
nb::arg("use_name_loc_as_prefix") = false,
32313239
nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
32323240
kOperationGetAsmDocstring)
32333241
.def("verify", &PyOperationBase::verify,

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -576,15 +576,16 @@ class PyOperationBase {
576576
/// Implements the bound 'print' method and helps with others.
577577
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
578578
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
579-
bool assumeVerified, nanobind::object fileObject, bool binary,
580-
bool skipRegions);
579+
bool useNameLocAsPrefix, bool assumeVerified,
580+
nanobind::object fileObject, bool binary, bool skipRegions);
581581
void print(PyAsmState &state, nanobind::object fileObject, bool binary);
582582

583583
nanobind::object getAsm(bool binary,
584584
std::optional<int64_t> largeElementsLimit,
585585
bool enableDebugInfo, bool prettyDebugInfo,
586586
bool printGenericOpForm, bool useLocalScope,
587-
bool assumeVerified, bool skipRegions);
587+
bool useNameLocAsPrefix, bool assumeVerified,
588+
bool skipRegions);
588589

589590
// Implement the bound 'writeBytecode' method.
590591
void writeBytecode(const nanobind::object &fileObject,

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
218218
unwrap(flags)->printGenericOpForm();
219219
}
220220

221+
void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) {
222+
unwrap(flags)->printNameLocAsPrefix();
223+
}
224+
221225
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
222226
unwrap(flags)->useLocalScope();
223227
}

mlir/test/python/ir/operation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,8 @@ def testOperationPrint():
643643
# Test print local_scope.
644644
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
645645
module.operation.print(enable_debug_info=True, use_local_scope=True)
646+
# CHECK: %nom = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
647+
module.operation.print(use_name_loc_as_prefix=True, use_local_scope=True)
646648

647649
# Test printing using state.
648650
state = AsmState(module.operation)

0 commit comments

Comments
 (0)