Skip to content

Commit

Permalink
[mlir][CAPI][python] bind CallSiteLoc, FileLineColRange, FusedLoc, Na…
Browse files Browse the repository at this point in the history
…meLoc, OpaqueLoc
  • Loading branch information
makslevental committed Mar 9, 2025
1 parent 667bbd2 commit 5fe8746
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 55 deletions.
80 changes: 80 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,22 +261,96 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(
MlirContext context, MlirStringRef filename, unsigned start_line,
unsigned start_col, unsigned end_line, unsigned end_col);

/// Getter for filename of FileLineColRange.
MLIR_CAPI_EXPORTED MlirIdentifier
mlirLocationFileLineColRangeGetFilename(MlirLocation location);

/// Getter for start_line of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetStartLine(MlirLocation location);

/// Getter for start_column of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetStartColumn(MlirLocation location);

/// Getter for end_line of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetEndLine(MlirLocation location);

/// Getter for end_column of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetEndColumn(MlirLocation location);

/// TypeID Getter for FileLineColRange.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFileLineColRangeGetTypeID(void);

/// Checks whether the given location is an FileLineColRange.
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location);

/// Creates a call site location with a callee and a caller.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee,
MlirLocation caller);

/// Getter for callee of CallSite.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationCallSiteGetCallee(MlirLocation location);

/// Getter for caller of CallSite.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationCallSiteGetCaller(MlirLocation location);

/// TypeID Getter for CallSite.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationCallSiteGetTypeID(void);

/// Checks whether the given location is an CallSite.
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location);

/// Creates a fused location with an array of locations and metadata.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations, MlirAttribute metadata);

/// Getter for number of locations fused together.
MLIR_CAPI_EXPORTED unsigned
mlirLocationFusedGetNumLocations(MlirLocation location);

/// Getter for locations of Fused. Requires pre-allocated memory of
/// #fusedLocations X sizeof(MlirLocation).
MLIR_CAPI_EXPORTED void
mlirLocationFusedGetLocations(MlirLocation location,
MlirLocation *locationsCPtr);

/// Getter for metadata of Fused.
MLIR_CAPI_EXPORTED MlirAttribute
mlirLocationFusedGetMetadata(MlirLocation location);

/// TypeID Getter for Fused.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFusedGetTypeID(void);

/// Checks whether the given location is an Fused.
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location);

/// Creates a name location owned by the given context. Providing null location
/// for childLoc is allowed and if childLoc is null location, then the behavior
/// is the same as having unknown child location.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context,
MlirStringRef name,
MlirLocation childLoc);

/// Getter for name of Name.
MLIR_CAPI_EXPORTED MlirIdentifier
mlirLocationNameGetName(MlirLocation location);

/// Getter for childLoc of Name.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationNameGetChildLoc(MlirLocation location);

/// TypeID Getter for Name.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationNameGetTypeID(void);

/// Checks whether the given location is an Name.
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location);

/// Creates a location with unknown position owned by the given context.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);

Expand Down Expand Up @@ -978,6 +1052,12 @@ mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
intptr_t numExceptions,
MlirOperation *exceptions);

/// Gets the location of the value.
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v);

/// Gets the context that a value was created with.
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v);

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ struct type_caster<MlirType> {
}
};

/// Casts MlirStringRef -> object.
template <>
struct type_caster<MlirStringRef> {
NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef"))
static handle from_cpp(MlirStringRef s, rv_policy,
cleanup_list *cleanup) noexcept {
return nanobind::str(s.data, s.length).release();
}
};

} // namespace detail
} // namespace nanobind

Expand Down
47 changes: 41 additions & 6 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2943,6 +2943,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("callee"), nb::arg("frames"),
nb::arg("context").none() = nb::none(),
kContextGetCallSiteLocationDocstring)
.def("is_a_callsite", mlirLocationIsACallSite)
.def_prop_ro("callee", mlirLocationCallSiteGetCallee)
.def_prop_ro("caller", mlirLocationCallSiteGetCaller)
.def_static(
"file",
[](std::string filename, int line, int col,
Expand All @@ -2967,6 +2970,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
nb::arg("end_line"), nb::arg("end_col"),
nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
.def("is_a_file", mlirLocationIsAFileLineColRange)
.def_prop_ro("filename",
[](MlirLocation loc) {
return mlirIdentifierStr(
mlirLocationFileLineColRangeGetFilename(loc));
})
.def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
.def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
.def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
.def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
.def_static(
"fused",
[](const std::vector<PyLocation> &pyLocations,
Expand All @@ -2984,6 +2997,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetFusedLocationDocstring)
.def("is_a_fused", mlirLocationIsAFused)
.def_prop_ro("locations",
[](MlirLocation loc) {
unsigned numLocations =
mlirLocationFusedGetNumLocations(loc);
std::vector<MlirLocation> locations(numLocations);
if (numLocations)
mlirLocationFusedGetLocations(loc, locations.data());
return locations;
})
.def_static(
"name",
[](std::string name, std::optional<PyLocation> childLoc,
Expand All @@ -2998,6 +3021,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetNameLocationDocString)
.def("is_a_name", mlirLocationIsAName)
.def_prop_ro("name_str",
[](MlirLocation loc) {
return mlirIdentifierStr(mlirLocationNameGetName(loc));
})
.def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
.def_static(
"from_attr",
[](PyAttribute &attribute, DefaultingPyMlirContext context) {
Expand Down Expand Up @@ -3148,9 +3177,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
auto &concreteOperation = self.getOperation();
concreteOperation.checkValid();
MlirOperation operation = concreteOperation.get();
MlirStringRef name =
mlirIdentifierStr(mlirOperationGetName(operation));
return nb::str(name.data, name.length);
return mlirIdentifierStr(mlirOperationGetName(operation));
})
.def_prop_ro("operands",
[](PyOperationBase &self) {
Expand Down Expand Up @@ -3738,8 +3765,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def_prop_ro(
"name",
[](PyNamedAttribute &self) {
return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
mlirIdentifierStr(self.namedAttr.name).length);
return mlirIdentifierStr(self.namedAttr.name);
},
"The name of the NamedAttribute binding")
.def_prop_ro(
Expand Down Expand Up @@ -3972,7 +3998,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("with"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
[](PyValue &self) { return self.maybeDownCast(); })
.def_prop_ro(
"location",
[](MlirValue self) {
return PyLocation(
PyMlirContext::forContext(mlirValueGetContext(self)),
mlirValueGetLocation(self));
},
"Returns the source location the value");

PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
Expand Down
112 changes: 106 additions & 6 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
}

MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute))));
return wrap(Location(llvm::dyn_cast<LocationAttr>(unwrap(attribute))));
}

MlirLocation mlirLocationFileLineColGet(MlirContext context,
Expand All @@ -278,10 +278,62 @@ mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename,
startLine, startCol, endLine, endCol)));
}

MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) {
return wrap(llvm::dyn_cast<FileLineColRange>(unwrap(location)).getFilename());
}

int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getStartLine();
return -1;
}

int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getStartColumn();
return -1;
}

int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getEndLine();
return -1;
}

int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getEndColumn();
return -1;
}

MlirTypeID mlirLocationFileLineColRangeGetTypeID() {
return wrap(FileLineColRange::getTypeID());
}

bool mlirLocationIsAFileLineColRange(MlirLocation location) {
return isa<FileLineColRange>(unwrap(location));
}

MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
}

MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) {
return wrap(Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCallee()));
}

MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) {
return wrap(Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCaller()));
}

MlirTypeID mlirLocationCallSiteGetTypeID() {
return wrap(CallSiteLoc::getTypeID());
}

bool mlirLocationIsACallSite(MlirLocation location) {
return isa<CallSiteLoc>(unwrap(location));
}

MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations,
MlirAttribute metadata) {
Expand All @@ -290,6 +342,30 @@ MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
}

unsigned mlirLocationFusedGetNumLocations(MlirLocation location) {
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location)))
return locationsArrRef.getLocations().size();
return 0;
}

void mlirLocationFusedGetLocations(MlirLocation location,
MlirLocation *locationsCPtr) {
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) {
for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations()))
locationsCPtr[i] = wrap(location);
}
}

MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) {
return wrap(llvm::dyn_cast<FusedLoc>(unwrap(location)).getMetadata());
}

MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); }

bool mlirLocationIsAFused(MlirLocation location) {
return isa<FusedLoc>(unwrap(location));
}

MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
MlirLocation childLoc) {
if (mlirLocationIsNull(childLoc))
Expand All @@ -299,6 +375,21 @@ MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
}

MlirIdentifier mlirLocationNameGetName(MlirLocation location) {
return wrap((llvm::dyn_cast<NameLoc>(unwrap(location)).getName()));
}

MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) {
return wrap(
Location(llvm::dyn_cast<NameLoc>(unwrap(location)).getChildLoc()));
}

MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); }

bool mlirLocationIsAName(MlirLocation location) {
return isa<NameLoc>(unwrap(location));
}

MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(Location(UnknownLoc::get(unwrap(context))));
}
Expand Down Expand Up @@ -975,25 +1066,26 @@ bool mlirValueIsAOpResult(MlirValue value) {
}

MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
return wrap(llvm::cast<BlockArgument>(unwrap(value)).getOwner());
return wrap(llvm::dyn_cast<BlockArgument>(unwrap(value)).getOwner());
}

intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
return static_cast<intptr_t>(
llvm::cast<BlockArgument>(unwrap(value)).getArgNumber());
llvm::dyn_cast<BlockArgument>(unwrap(value)).getArgNumber());
}

void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
llvm::cast<BlockArgument>(unwrap(value)).setType(unwrap(type));
if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
blockArg.setType(unwrap(type));
}

MlirOperation mlirOpResultGetOwner(MlirValue value) {
return wrap(llvm::cast<OpResult>(unwrap(value)).getOwner());
return wrap(llvm::dyn_cast<OpResult>(unwrap(value)).getOwner());
}

intptr_t mlirOpResultGetResultNumber(MlirValue value) {
return static_cast<intptr_t>(
llvm::cast<OpResult>(unwrap(value)).getResultNumber());
llvm::dyn_cast<OpResult>(unwrap(value)).getResultNumber());
}

MlirType mlirValueGetType(MlirValue value) {
Expand Down Expand Up @@ -1047,6 +1139,14 @@ void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
}

MlirLocation mlirValueGetLocation(MlirValue v) {
return wrap(unwrap(v).getLoc());
}

MlirContext mlirValueGetContext(MlirValue v) {
return wrap(unwrap(v).getContext());
}

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 5fe8746

Please sign in to comment.