Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][CAPI][python] bind CallSiteLoc, FileLineColRange, FusedLoc, NameLoc #129351

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,22 +261,96 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(
MlirContext context, MlirStringRef filename, unsigned start_line,
unsigned start_col, unsigned end_line, unsigned end_col);

/// Getter for filename of FileLineColRange.
MLIR_CAPI_EXPORTED MlirIdentifier
mlirLocationFileLineColRangeGetFilename(MlirLocation location);

/// Getter for start_line of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetStartLine(MlirLocation location);

/// Getter for start_column of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetStartColumn(MlirLocation location);

/// Getter for end_line of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetEndLine(MlirLocation location);

/// Getter for end_column of FileLineColRange.
MLIR_CAPI_EXPORTED int
mlirLocationFileLineColRangeGetEndColumn(MlirLocation location);

/// TypeID Getter for FileLineColRange.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFileLineColRangeGetTypeID(void);

/// Checks whether the given location is an FileLineColRange.
MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location);

/// Creates a call site location with a callee and a caller.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee,
MlirLocation caller);

/// Getter for callee of CallSite.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationCallSiteGetCallee(MlirLocation location);

/// Getter for caller of CallSite.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationCallSiteGetCaller(MlirLocation location);

/// TypeID Getter for CallSite.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationCallSiteGetTypeID(void);

/// Checks whether the given location is an CallSite.
MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location);

/// Creates a fused location with an array of locations and metadata.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations, MlirAttribute metadata);

/// Getter for number of locations fused together.
MLIR_CAPI_EXPORTED unsigned
mlirLocationFusedGetNumLocations(MlirLocation location);

/// Getter for locations of Fused. Requires pre-allocated memory of
/// #fusedLocations X sizeof(MlirLocation).
MLIR_CAPI_EXPORTED void
mlirLocationFusedGetLocations(MlirLocation location,
MlirLocation *locationsCPtr);

/// Getter for metadata of Fused.
MLIR_CAPI_EXPORTED MlirAttribute
mlirLocationFusedGetMetadata(MlirLocation location);

/// TypeID Getter for Fused.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFusedGetTypeID(void);

/// Checks whether the given location is an Fused.
MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location);

/// Creates a name location owned by the given context. Providing null location
/// for childLoc is allowed and if childLoc is null location, then the behavior
/// is the same as having unknown child location.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context,
MlirStringRef name,
MlirLocation childLoc);

/// Getter for name of Name.
MLIR_CAPI_EXPORTED MlirIdentifier
mlirLocationNameGetName(MlirLocation location);

/// Getter for childLoc of Name.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationNameGetChildLoc(MlirLocation location);

/// TypeID Getter for Name.
MLIR_CAPI_EXPORTED MlirTypeID mlirLocationNameGetTypeID(void);

/// Checks whether the given location is an Name.
MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location);

/// Creates a location with unknown position owned by the given context.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);

Expand Down Expand Up @@ -978,6 +1052,12 @@ mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
intptr_t numExceptions,
MlirOperation *exceptions);

/// Gets the location of the value.
MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v);

/// Gets the context that a value was created with.
MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v);

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ struct type_caster<MlirType> {
}
};

/// Casts MlirStringRef -> object.
template <>
struct type_caster<MlirStringRef> {
NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef"))
static handle from_cpp(MlirStringRef s, rv_policy,
cleanup_list *cleanup) noexcept {
return nanobind::str(s.data, s.length).release();
}
};

} // namespace detail
} // namespace nanobind

Expand Down
47 changes: 41 additions & 6 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2943,6 +2943,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("callee"), nb::arg("frames"),
nb::arg("context").none() = nb::none(),
kContextGetCallSiteLocationDocstring)
.def("is_a_callsite", mlirLocationIsACallSite)
.def_prop_ro("callee", mlirLocationCallSiteGetCallee)
.def_prop_ro("caller", mlirLocationCallSiteGetCaller)
.def_static(
"file",
[](std::string filename, int line, int col,
Expand All @@ -2967,6 +2970,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
nb::arg("end_line"), nb::arg("end_col"),
nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
.def("is_a_file", mlirLocationIsAFileLineColRange)
.def_prop_ro("filename",
[](MlirLocation loc) {
return mlirIdentifierStr(
mlirLocationFileLineColRangeGetFilename(loc));
})
.def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
.def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
.def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
.def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
.def_static(
"fused",
[](const std::vector<PyLocation> &pyLocations,
Expand All @@ -2984,6 +2997,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetFusedLocationDocstring)
.def("is_a_fused", mlirLocationIsAFused)
.def_prop_ro("locations",
[](MlirLocation loc) {
unsigned numLocations =
mlirLocationFusedGetNumLocations(loc);
std::vector<MlirLocation> locations(numLocations);
if (numLocations)
mlirLocationFusedGetLocations(loc, locations.data());
return locations;
})
.def_static(
"name",
[](std::string name, std::optional<PyLocation> childLoc,
Expand All @@ -2998,6 +3021,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetNameLocationDocString)
.def("is_a_name", mlirLocationIsAName)
.def_prop_ro("name_str",
[](MlirLocation loc) {
return mlirIdentifierStr(mlirLocationNameGetName(loc));
})
.def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
.def_static(
"from_attr",
[](PyAttribute &attribute, DefaultingPyMlirContext context) {
Expand Down Expand Up @@ -3148,9 +3177,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
auto &concreteOperation = self.getOperation();
concreteOperation.checkValid();
MlirOperation operation = concreteOperation.get();
MlirStringRef name =
mlirIdentifierStr(mlirOperationGetName(operation));
return nb::str(name.data, name.length);
return mlirIdentifierStr(mlirOperationGetName(operation));
})
.def_prop_ro("operands",
[](PyOperationBase &self) {
Expand Down Expand Up @@ -3738,8 +3765,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def_prop_ro(
"name",
[](PyNamedAttribute &self) {
return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
mlirIdentifierStr(self.namedAttr.name).length);
return mlirIdentifierStr(self.namedAttr.name);
},
"The name of the NamedAttribute binding")
.def_prop_ro(
Expand Down Expand Up @@ -3972,7 +3998,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("with"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
[](PyValue &self) { return self.maybeDownCast(); })
.def_prop_ro(
"location",
[](MlirValue self) {
return PyLocation(
PyMlirContext::forContext(mlirValueGetContext(self)),
mlirValueGetLocation(self));
},
"Returns the source location the value");

PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
Expand Down
114 changes: 108 additions & 6 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
}

MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute))));
return wrap(Location(llvm::dyn_cast<LocationAttr>(unwrap(attribute))));
}

MlirLocation mlirLocationFileLineColGet(MlirContext context,
Expand All @@ -278,10 +278,64 @@ mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename,
startLine, startCol, endLine, endCol)));
}

MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) {
return wrap(llvm::dyn_cast<FileLineColRange>(unwrap(location)).getFilename());
}

int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getStartLine();
return -1;
}

int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getStartColumn();
return -1;
}

int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getEndLine();
return -1;
}

int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) {
if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
return loc.getEndColumn();
return -1;
}

MlirTypeID mlirLocationFileLineColRangeGetTypeID() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are the TypeID ones used? (It feels a little bit like a C++ implementation detail)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're not currently but they can be used for value builders (that thing I added a while ago for hooking cpp -> python) so that's why I included those APIs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the PR I'm talking about #69644

return wrap(FileLineColRange::getTypeID());
}

bool mlirLocationIsAFileLineColRange(MlirLocation location) {
return isa<FileLineColRange>(unwrap(location));
}

MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
}

MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) {
return wrap(
Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCallee()));
}

MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) {
return wrap(
Location(llvm::dyn_cast<CallSiteLoc>(unwrap(location)).getCaller()));
}

MlirTypeID mlirLocationCallSiteGetTypeID() {
return wrap(CallSiteLoc::getTypeID());
}

bool mlirLocationIsACallSite(MlirLocation location) {
return isa<CallSiteLoc>(unwrap(location));
}

MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations,
MlirAttribute metadata) {
Expand All @@ -290,6 +344,30 @@ MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
}

unsigned mlirLocationFusedGetNumLocations(MlirLocation location) {
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location)))
return locationsArrRef.getLocations().size();
return 0;
}

void mlirLocationFusedGetLocations(MlirLocation location,
MlirLocation *locationsCPtr) {
if (auto locationsArrRef = llvm::dyn_cast<FusedLoc>(unwrap(location))) {
for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations()))
locationsCPtr[i] = wrap(location);
}
}

MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) {
return wrap(llvm::dyn_cast<FusedLoc>(unwrap(location)).getMetadata());
}

MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); }

bool mlirLocationIsAFused(MlirLocation location) {
return isa<FusedLoc>(unwrap(location));
}

MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
MlirLocation childLoc) {
if (mlirLocationIsNull(childLoc))
Expand All @@ -299,6 +377,21 @@ MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
}

MlirIdentifier mlirLocationNameGetName(MlirLocation location) {
return wrap((llvm::dyn_cast<NameLoc>(unwrap(location)).getName()));
}

MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) {
return wrap(
Location(llvm::dyn_cast<NameLoc>(unwrap(location)).getChildLoc()));
}

MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); }

bool mlirLocationIsAName(MlirLocation location) {
return isa<NameLoc>(unwrap(location));
}

MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(Location(UnknownLoc::get(unwrap(context))));
}
Expand Down Expand Up @@ -975,25 +1068,26 @@ bool mlirValueIsAOpResult(MlirValue value) {
}

MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
return wrap(llvm::cast<BlockArgument>(unwrap(value)).getOwner());
return wrap(llvm::dyn_cast<BlockArgument>(unwrap(value)).getOwner());
}

intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
return static_cast<intptr_t>(
llvm::cast<BlockArgument>(unwrap(value)).getArgNumber());
llvm::dyn_cast<BlockArgument>(unwrap(value)).getArgNumber());
}

void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
llvm::cast<BlockArgument>(unwrap(value)).setType(unwrap(type));
if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
blockArg.setType(unwrap(type));
}

MlirOperation mlirOpResultGetOwner(MlirValue value) {
return wrap(llvm::cast<OpResult>(unwrap(value)).getOwner());
return wrap(llvm::dyn_cast<OpResult>(unwrap(value)).getOwner());
}

intptr_t mlirOpResultGetResultNumber(MlirValue value) {
return static_cast<intptr_t>(
llvm::cast<OpResult>(unwrap(value)).getResultNumber());
llvm::dyn_cast<OpResult>(unwrap(value)).getResultNumber());
}

MlirType mlirValueGetType(MlirValue value) {
Expand Down Expand Up @@ -1047,6 +1141,14 @@ void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
}

MlirLocation mlirValueGetLocation(MlirValue v) {
return wrap(unwrap(v).getLoc());
}

MlirContext mlirValueGetContext(MlirValue v) {
return wrap(unwrap(v).getContext());
}

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
Loading