Skip to content

Commit

Permalink
[mlir][CAPI][python] bind CallSiteLoc, FileLineColRange, FusedLoc, Na…
Browse files Browse the repository at this point in the history
…meLoc, OpaqueLoc
  • Loading branch information
makslevental committed Mar 9, 2025
1 parent 7612dcc commit f595ba3
Show file tree
Hide file tree
Showing 5 changed files with 364 additions and 53 deletions.
80 changes: 80 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,22 +261,96 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(
MlirContext context, MlirStringRef filename, unsigned start_line,
unsigned start_col, unsigned end_line, unsigned end_col);

/// Getter for filename of FileLineColRange.
MLIR_CAPI_EXPORTED MlirIdentifier
mlirLocationFileLineColRangeGetFilename(MlirLocation location);

/// Getter for start_line of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetStartLine(MlirLocation location);

/// Getter for start_column of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetStartColumn(MlirLocation location);

/// Getter for end_line of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetEndLine(MlirLocation location);

/// Getter for end_column of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetEndColumn(MlirLocation location);

/// TypeID Getter for FileLineColRange.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFileLineColRangeGetTypeID(void);

/// Checks whether the given location is an FileLineColRange.
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location);

/// Creates a call site location with a callee and a caller.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee,
MlirLocation caller);

/// Getter for callee of CallSite.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationCallSiteGetCallee(MlirLocation location);

/// Getter for caller of CallSite.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationCallSiteGetCaller(MlirLocation location);

/// TypeID Getter for CallSite.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationCallSiteGetTypeID(void);

/// Checks whether the given location is an CallSite.
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location);

/// Creates a fused location with an array of locations and metadata.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations, MlirAttribute metadata);

/// Getter for number of locations fused together.
MLIR_CAPI_EXPORTED unsigned
mlirLocationFusedGetNumLocations(MlirLocation location);

/// Getter for locations of Fused. Requires pre-allocated memory of
/// #fusedLocations X sizeof(MlirLocation).
MLIR_CAPI_EXPORTED void
mlirLocationFusedGetLocations(MlirLocation location,
MlirLocation *locationsCPtr);

/// Getter for metadata of Fused.
MLIR_CAPI_EXPORTED MlirAttribute
mlirLocationFusedGetMetadata(MlirLocation location);

/// TypeID Getter for Fused.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFusedGetTypeID(void);

/// Checks whether the given location is an Fused.
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location);

/// Creates a name location owned by the given context. Providing null location
/// for childLoc is allowed and if childLoc is null location, then the behavior
/// is the same as having unknown child location.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context,
MlirStringRef name,
MlirLocation childLoc);

/// Getter for name of Name.
MLIR_CAPI_EXPORTED MlirIdentifier
mlirLocationNameGetName(MlirLocation location);

/// Getter for childLoc of Name.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationNameGetChildLoc(MlirLocation location);

/// TypeID Getter for Name.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationNameGetTypeID(void);

/// Checks whether the given location is an Name.
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location);

/// Creates a location with unknown position owned by the given context.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);

Expand Down Expand Up @@ -970,6 +1044,12 @@ mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
intptr_t numExceptions,
MlirOperation *exceptions);

/// Gets the location of the value.
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v);

/// Gets the context that a value was created with.
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v);

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 11 additions & 1 deletion mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

#include <cstdint>

#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "llvm/ADT/Twine.h"

// Raw CAPI type casters need to be declared before use, so always include them
Expand Down Expand Up @@ -319,6 +319,16 @@ struct type_caster<MlirType> {
}
};

/// Casts MlirStringRef -> object.
template <>
struct type_caster<MlirStringRef> {
NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef"))
static handle from_cpp(MlirStringRef s, rv_policy,
cleanup_list *cleanup) noexcept {
return nanobind::str(s.data, s.length).release();
}
};

} // namespace detail
} // namespace nanobind

Expand Down
51 changes: 43 additions & 8 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
#include "Globals.h"
#include "IRModule.h"
#include "NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

Expand Down Expand Up @@ -299,7 +299,7 @@ struct PyAttrBuilderMap {
return *builder;
}
static void dunderSetItemNamed(const std::string &attributeKind,
nb::callable func, bool replace) {
nb::callable func, bool replace) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
replace);
}
Expand Down Expand Up @@ -2933,6 +2933,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("callee"), nb::arg("frames"),
nb::arg("context").none() = nb::none(),
kContextGetCallSiteLocationDocstring)
.def("is_a_callsite", mlirLocationIsACallSite)
.def_prop_ro("callee", mlirLocationCallSiteGetCallee)
.def_prop_ro("caller", mlirLocationCallSiteGetCaller)
.def_static(
"file",
[](std::string filename, int line, int col,
Expand All @@ -2957,6 +2960,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
nb::arg("end_line"), nb::arg("end_col"),
nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
.def("is_a_file", mlirLocationIsAFileLineColRange)
.def_prop_ro("filename",
[](MlirLocation loc) {
return mlirIdentifierStr(
mlirLocationFileLineColRangeGetFilename(loc));
})
.def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
.def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
.def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
.def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
.def_static(
"fused",
[](const std::vector<PyLocation> &pyLocations,
Expand All @@ -2974,6 +2987,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetFusedLocationDocstring)
.def("is_a_fused", mlirLocationIsAFused)
.def_prop_ro("locations",
[](MlirLocation loc) {
unsigned numLocations =
mlirLocationFusedGetNumLocations(loc);
std::vector<MlirLocation> locations(numLocations);
if (numLocations)
mlirLocationFusedGetLocations(loc, locations.data());
return locations;
})
.def_static(
"name",
[](std::string name, std::optional<PyLocation> childLoc,
Expand All @@ -2988,6 +3011,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetNameLocationDocString)
.def("is_a_name", mlirLocationIsAName)
.def_prop_ro("name_str",
[](MlirLocation loc) {
return mlirIdentifierStr(mlirLocationNameGetName(loc));
})
.def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
.def_static(
"from_attr",
[](PyAttribute &attribute, DefaultingPyMlirContext context) {
Expand Down Expand Up @@ -3126,9 +3155,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
auto &concreteOperation = self.getOperation();
concreteOperation.checkValid();
MlirOperation operation = concreteOperation.get();
MlirStringRef name =
mlirIdentifierStr(mlirOperationGetName(operation));
return nb::str(name.data, name.length);
return mlirIdentifierStr(mlirOperationGetName(operation));
})
.def_prop_ro("operands",
[](PyOperationBase &self) {
Expand Down Expand Up @@ -3713,8 +3740,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def_prop_ro(
"name",
[](PyNamedAttribute &self) {
return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
mlirIdentifierStr(self.namedAttr.name).length);
return mlirIdentifierStr(self.namedAttr.name);
},
"The name of the NamedAttribute binding")
.def_prop_ro(
Expand Down Expand Up @@ -3947,7 +3973,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("with"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
[](PyValue &self) { return self.maybeDownCast(); })
.def_prop_ro(
"location",
[](MlirValue self) {
return PyLocation(
PyMlirContext::forContext(mlirValueGetContext(self)),
mlirValueGetLocation(self));
},
"Returns the source location the value");

PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
Expand Down
100 changes: 99 additions & 1 deletion mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
}

MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute))));
return wrap(Location(llvm::dyn_cast<LocationAttr>(unwrap(attribute))));
}

MlirLocation mlirLocationFileLineColGet(MlirContext context,
Expand All @@ -273,10 +273,62 @@ mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename,
startLine, startCol, endLine, endCol)));
}

MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) {
return wrap(llvm::dyn_cast<FileLineColRange>(unwrap(location)).getFilename());
}

int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getStartLine();
return -1;
}

int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getStartColumn();
return -1;
}

int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getEndLine();
return -1;
}

int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getEndColumn();
return -1;
}

MlirTypeID mlirLocationFileLineColRangeGetTypeID() {
return wrap(FileLineColRange::getTypeID());
}

bool mlirLocationIsAFileLineColRange(MlirLocation location) {
return isa<FileLineColRange>(unwrap(location));
}

MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
}

MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) {
return wrap(Location(llvm::cast<CallSiteLoc>(unwrap(location)).getCallee()));
}

MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) {
return wrap(Location(llvm::cast<CallSiteLoc>(unwrap(location)).getCaller()));
}

MlirTypeID mlirLocationCallSiteGetTypeID() {
return wrap(CallSiteLoc::getTypeID());
}

bool mlirLocationIsACallSite(MlirLocation location) {
return isa<CallSiteLoc>(unwrap(location));
}

MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations,
MlirAttribute metadata) {
Expand All @@ -285,6 +337,30 @@ MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
}

unsigned mlirLocationFusedGetNumLocations(MlirLocation location) {
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location)))
return locationsArrRef.getLocations().size();
return 0;
}

void mlirLocationFusedGetLocations(MlirLocation location,
MlirLocation *locationsCPtr) {
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) {
for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations()))
locationsCPtr[i] = wrap(location);
}
}

MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) {
return wrap(llvm::cast<FusedLoc>(unwrap(location)).getMetadata());
}

MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); }

bool mlirLocationIsAFused(MlirLocation location) {
return isa<FusedLoc>(unwrap(location));
}

MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
MlirLocation childLoc) {
if (mlirLocationIsNull(childLoc))
Expand All @@ -294,6 +370,20 @@ MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
}

MlirIdentifier mlirLocationNameGetName(MlirLocation location) {
return wrap((llvm::cast<NameLoc>(unwrap(location)).getName()));
}

MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) {
return wrap(Location(llvm::cast<NameLoc>(unwrap(location)).getChildLoc()));
}

MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); }

bool mlirLocationIsAName(MlirLocation location) {
return isa<NameLoc>(unwrap(location));
}

MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(Location(UnknownLoc::get(unwrap(context))));
}
Expand Down Expand Up @@ -1033,6 +1123,14 @@ void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
}

MlirLocation mlirValueGetLocation(MlirValue v) {
return wrap(unwrap(v).getLoc());
}

MlirContext mlirValueGetContext(MlirValue v) {
return wrap(unwrap(v).getContext());
}

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit f595ba3

Please sign in to comment.