From 198c731ba747cb87b972feecb39d96c1d0fedec7 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 3 Mar 2025 23:11:46 +0000 Subject: [PATCH] [mlir][py] Plumb through name loc as prefix --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++++++++---- mlir/lib/Bindings/Python/IRModule.h | 7 ++++--- mlir/lib/CAPI/IR/IR.cpp | 4 ++++ mlir/test/python/ir/operation.py | 2 ++ 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 14ccae650606a..d562da1f90757 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -456,6 +456,10 @@ mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags); +/// Print the name and location, if NamedLoc, as a prefix to the SSA ID. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags); + /// Use local scope when printing the operation. This allows for using the /// printer in a more localized and thread-safe setting, but may not /// necessarily be identical to what the IR will look like when dumping diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b13a429d4a3c0..12793f7dd15be 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1291,8 +1291,9 @@ void PyOperation::checkValid() const { void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, nb::object fileObject, - bool binary, bool skipRegions) { + bool useNameLocAsPrefix, bool assumeVerified, + nb::object fileObject, bool binary, + bool skipRegions) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) @@ -1314,6 +1315,8 @@ void PyOperationBase::print(std::optional largeElementsLimit, mlirOpPrintingFlagsAssumeVerified(flags); if (skipRegions) mlirOpPrintingFlagsSkipRegions(flags); + if (useNameLocAsPrefix) + mlirOpPrintingFlagsPrintNameLocAsPrefix(flags); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), @@ -1390,7 +1393,8 @@ nb::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, bool skipRegions) { + bool useNameLocAsPrefix, bool assumeVerified, + bool skipRegions) { nb::object fileObject; if (binary) { fileObject = nb::module_::import_("io").attr("BytesIO")(); @@ -1402,6 +1406,7 @@ nb::object PyOperationBase::getAsm(bool binary, /*prettyDebugInfo=*/prettyDebugInfo, /*printGenericOpForm=*/printGenericOpForm, /*useLocalScope=*/useLocalScope, + /*useNameLocAsPrefix=*/useNameLocAsPrefix, /*assumeVerified=*/assumeVerified, /*fileObject=*/fileObject, /*binary=*/binary, @@ -3195,6 +3200,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { /*prettyDebugInfo=*/false, /*printGenericOpForm=*/false, /*useLocalScope=*/false, + /*useNameLocAsPrefix=*/false, /*assumeVerified=*/false, /*skipRegions=*/false); }, @@ -3206,7 +3212,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", nb::overload_cast, bool, bool, bool, bool, - bool, nb::object, bool, bool>( + bool, bool, nb::object, bool, bool>( &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. nb::arg("large_elements_limit").none() = nb::none(), @@ -3214,6 +3220,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("file").none() = nb::none(), nb::arg("binary") = false, nb::arg("skip_regions") = false, kOperationPrintDocstring) @@ -3228,6 +3235,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, kOperationGetAsmDocstring) .def("verify", &PyOperationBase::verify, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index dd6e7ef912374..1ed6240a6ca69 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -576,15 +576,16 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, nanobind::object fileObject, bool binary, - bool skipRegions); + bool useNameLocAsPrefix, bool assumeVerified, + nanobind::object fileObject, bool binary, bool skipRegions); void print(PyAsmState &state, nanobind::object fileObject, bool binary); nanobind::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, bool skipRegions); + bool useNameLocAsPrefix, bool assumeVerified, + bool skipRegions); // Implement the bound 'writeBytecode' method. void writeBytecode(const nanobind::object &fileObject, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 999e8cbda1295..6cd9ba2aef233 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -218,6 +218,10 @@ void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { unwrap(flags)->printGenericOpForm(); } +void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) { + unwrap(flags)->printNameLocAsPrefix(); +} + void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { unwrap(flags)->useLocalScope(); } diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 090d0030fb062..dd2731ba2e1f1 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -643,6 +643,8 @@ def testOperationPrint(): # Test print local_scope. # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom") module.operation.print(enable_debug_info=True, use_local_scope=True) + # CHECK: %nom = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + module.operation.print(use_name_loc_as_prefix=True, use_local_scope=True) # Test printing using state. state = AsmState(module.operation)