Skip to content

Commit

Permalink
Implement PyAnyTorchListOfTensorValue list __getitem__ (#57)
Browse files Browse the repository at this point in the history
Co-authored-by: Arham Khan <[email protected]>
  • Loading branch information
123epsilon and Arham Khan authored Jun 28, 2023
1 parent 0e5d5cc commit 9b6129d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
9 changes: 9 additions & 0 deletions cpp_ext/TorchTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,15 @@ void PyAnyTorchListOfTensorValue::bindDerived(ClassTy &c) {
self, &DefaultingPyLocation::resolve(),
&DefaultingPyInsertionPoint::resolve())));
});
c.def(
"__getitem__",
[](const PyAnyTorchListOfTensorValue &self,
const PyTorch_IntValue &idx) -> PyAnyTorchTensorValue {
return makeGetItem<PyAnyTorchTensorValue>(
self, idx, &DefaultingPyLocation::resolve(),
&DefaultingPyInsertionPoint::resolve());
},
"idx"_a);
}

PyAnyTorchListOfOptionalTensorValue
Expand Down
8 changes: 8 additions & 0 deletions cpp_ext/TorchValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ T makeGetItem(U &self, const PyTorch_IntValue &idx, PyLocation *loc,
t = torchMlirTorchIntTypeGet(loc->getContext()->get());
else if (std::is_same<T, PyTorch_StringValue>::value)
t = torchMlirTorchStringTypeGet(loc->getContext()->get());
else if (std::is_same<T, PyAnyTorchTensorValue>::value)
t = torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
loc->getContext()->get());
else
throw std::runtime_error("unknown element type");
auto resultType = py::cast(t).cast<PyType>();
Expand All @@ -410,6 +413,11 @@ T makeGetItem(U &self, const PyTorch_IntValue &idx, PyLocation *loc,
return {opRef, mlirOperationGetResult(opRef->get(), 0)};
}

template PyAnyTorchTensorValue
makeGetItem<PyAnyTorchTensorValue>(const PyAnyTorchListOfTensorValue &self,
const PyTorch_IntValue &idx, PyLocation *loc,
PyInsertionPoint *ip);

MlirOperation getOwner(const PyValue &value) {
MlirOperation owner;
if (mlirValueIsAOpResult(value))
Expand Down
4 changes: 4 additions & 0 deletions cpp_ext/TorchValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ PyAnyTorchListValue makePyAnyTorchListValue(const py::object &type,
PyLocation *loc,
PyInsertionPoint *ip);

template <typename T, typename U>
T makeGetItem(U &self, const PyTorch_IntValue &idx, PyLocation *loc,
PyInsertionPoint *ip);

class PyTorch_NoneValue : public PyConcreteValue<PyTorch_NoneValue> {
public:
static constexpr IsAFunctionTy isaFunction = isATorch_NoneValue;
Expand Down

0 comments on commit 9b6129d

Please sign in to comment.