Skip to content

Commit

Permalink
[mlir][py] Plumb OpPrintingFlags::printNameLocAsPrefix() through the …
Browse files Browse the repository at this point in the history
…C/Python APIs (#129607)
  • Loading branch information
jpienaar authored Mar 4, 2025
1 parent d38380d commit 540d7dd
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 7 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,10 @@ mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable,
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags);

/// Print the name and location, if NamedLoc, as a prefix to the SSA ID.
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags);

/// Use local scope when printing the operation. This allows for using the
/// printer in a more localized and thread-safe setting, but may not
/// necessarily be identical to what the IR will look like when dumping
Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1291,8 +1291,9 @@ void PyOperation::checkValid() const {
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, nb::object fileObject,
bool binary, bool skipRegions) {
bool useNameLocAsPrefix, bool assumeVerified,
nb::object fileObject, bool binary,
bool skipRegions) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
Expand All @@ -1314,6 +1315,8 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
mlirOpPrintingFlagsAssumeVerified(flags);
if (skipRegions)
mlirOpPrintingFlagsSkipRegions(flags);
if (useNameLocAsPrefix)
mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);

PyFileAccumulator accum(fileObject, binary);
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
Expand Down Expand Up @@ -1390,7 +1393,8 @@ nb::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, bool skipRegions) {
bool useNameLocAsPrefix, bool assumeVerified,
bool skipRegions) {
nb::object fileObject;
if (binary) {
fileObject = nb::module_::import_("io").attr("BytesIO")();
Expand All @@ -1402,6 +1406,7 @@ nb::object PyOperationBase::getAsm(bool binary,
/*prettyDebugInfo=*/prettyDebugInfo,
/*printGenericOpForm=*/printGenericOpForm,
/*useLocalScope=*/useLocalScope,
/*useNameLocAsPrefix=*/useNameLocAsPrefix,
/*assumeVerified=*/assumeVerified,
/*fileObject=*/fileObject,
/*binary=*/binary,
Expand Down Expand Up @@ -3195,6 +3200,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
/*prettyDebugInfo=*/false,
/*printGenericOpForm=*/false,
/*useLocalScope=*/false,
/*useNameLocAsPrefix=*/false,
/*assumeVerified=*/false,
/*skipRegions=*/false);
},
Expand All @@ -3206,14 +3212,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
bool, nb::object, bool, bool>(
bool, bool, nb::object, bool, bool>(
&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
nb::arg("large_elements_limit").none() = nb::none(),
nb::arg("enable_debug_info") = false,
nb::arg("pretty_debug_info") = false,
nb::arg("print_generic_op_form") = false,
nb::arg("use_local_scope") = false,
nb::arg("use_name_loc_as_prefix") = false,
nb::arg("assume_verified") = false,
nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
nb::arg("skip_regions") = false, kOperationPrintDocstring)
Expand All @@ -3228,6 +3235,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("pretty_debug_info") = false,
nb::arg("print_generic_op_form") = false,
nb::arg("use_local_scope") = false,
nb::arg("use_name_loc_as_prefix") = false,
nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
kOperationGetAsmDocstring)
.def("verify", &PyOperationBase::verify,
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,15 +576,16 @@ class PyOperationBase {
/// Implements the bound 'print' method and helps with others.
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, nanobind::object fileObject, bool binary,
bool skipRegions);
bool useNameLocAsPrefix, bool assumeVerified,
nanobind::object fileObject, bool binary, bool skipRegions);
void print(PyAsmState &state, nanobind::object fileObject, bool binary);

nanobind::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, bool skipRegions);
bool useNameLocAsPrefix, bool assumeVerified,
bool skipRegions);

// Implement the bound 'writeBytecode' method.
void writeBytecode(const nanobind::object &fileObject,
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
unwrap(flags)->printGenericOpForm();
}

void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) {
unwrap(flags)->printNameLocAsPrefix();
}

void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
unwrap(flags)->useLocalScope();
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,8 @@ def testOperationPrint():
# Test print local_scope.
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
module.operation.print(enable_debug_info=True, use_local_scope=True)
# CHECK: %nom = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
module.operation.print(use_name_loc_as_prefix=True, use_local_scope=True)

# Test printing using state.
state = AsmState(module.operation)
Expand Down

0 comments on commit 540d7dd

Please sign in to comment.