Implement Tensor indexing (partial), addition, division, and list concatenation (#53)

Co-authored-by: Arham Khan <[email protected]>
123epsilon and Arham Khan authored Jun 28, 2023
1 parent 78dc279 commit 0e5d5cc
Showing 10 changed files with 217 additions and 11 deletions.
3 changes: 2 additions & 1 deletion cpp_ext/TorchOps.cpp
@@ -203,7 +203,8 @@ void populateTorchMLIROps(py::module &m) {
const PyTorch_BoolValue &keepdim,
const PyAnyTorchOptionalIntValue &dtype, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
-          return mean(self, dim, keepdim, dtype, loc.get(), ip.get());
+          auto dims = PyAnyTorchOptionalListOfTorchIntValue(py::make_tuple(dim));
+          return mean(self, dims, keepdim, dtype, loc.get(), ip.get());
},
"self"_a, "dim"_a = py::none(), "keepdim"_a = false,
"dtype"_a = py::none(), py::kw_only(), "loc"_a = py::none(),
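The fix above lets mean accept a plain Python int for dim by wrapping it in a one-element optional int list before dispatching to aten::mean.dim. A minimal usage sketch, assuming the binding is surfaced as pi.ops.mean like the other ops registered in populateTorchMLIROps (helpers as used in this commit's tests):

    # Hedged sketch: a bare int `dim` now reaches aten::mean.dim as a
    # one-element !torch.list<int> rather than a raw !torch.int.
    import pi
    from pi import ops
    from pi.mlir.utils import mlir_mod_ctx

    with mlir_mod_ctx():
        x = pi.ones((4, 4))
        m = ops.mean(x, 0)  # dim wrapped via py::make_tuple(dim)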
101 changes: 101 additions & 0 deletions cpp_ext/TorchTensor.cpp
@@ -288,6 +288,107 @@ void PyAnyTorchTensorValue::bindDerived(ClassTy &c) {
"chunks"_a, "dim"_a = 0, py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// __truediv__(self, other: Any) -> Tensor
// aten::div.Scalar : (Tensor, Scalar) -> (Tensor)
c.def(
"__truediv__",
[](const PyAnyTorchTensorValue &self,
PyAnyTorchScalarValue &other) -> PyAnyTorchTensorValue {
auto loc = getValueLocation(self);
return div(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);

// aten::div.Tensor : (Tensor, Tensor) -> (Tensor)
c.def(
"__truediv__",
[](const PyAnyTorchTensorValue &self,
PyAnyTorchTensorValue &other) -> PyAnyTorchTensorValue {
auto loc = getValueLocation(self);
return div(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);

// __rtruediv__(self, other: Any) -> Tensor
c.def(
"__rtruediv__",
[](const PyAnyTorchTensorValue &self,
PyAnyTorchScalarValue &other) -> PyAnyTorchTensorValue {
auto loc = getValueLocation(self);
PyAnyTorchTensorValue recip =
reciprocal(self, &loc, &DefaultingPyInsertionPoint::resolve());
auto recip_loc = getValueLocation(recip);
return mul(recip, other, &recip_loc,
&DefaultingPyInsertionPoint::resolve());
},
"other"_a);

// __getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple])
// -> Tensor
// __getitem__(self, int) -> Tensor
c.def(
"__getitem__",
[](const PyAnyTorchTensorValue &self,
const PyTorch_IntValue &index) -> PyAnyTorchTensorValue {
auto loc = getValueLocation(self);
return select(self, 0, index, &loc,
&DefaultingPyInsertionPoint::resolve());
},
"index"_a);

// __getitem__(self, None) -> Tensor
c.def("__getitem__",
[](const PyAnyTorchTensorValue &self,
const py::none &noneValue) -> PyAnyTorchTensorValue {
auto loc = getValueLocation(self);
return unsqueeze(self, 0, &loc,
&DefaultingPyInsertionPoint::resolve());
});

// __getitem__(self, slice) -> Tensor
c.def("__getitem__",
[](const PyAnyTorchTensorValue &self,
py::slice &sliceObject) -> PyAnyTorchTensorValue {
int dim = 0;

auto parseAttr =
[](const py::object &obj) -> PyAnyTorchOptionalIntValue {
if (py::isinstance<py::none>(obj)) {
return obj.cast<PyAnyTorchOptionalIntValue>();
} else if (py::isinstance<py::int_>(obj)) {
return obj.cast<int>();
} else {
throw std::invalid_argument(
"Invalid: aten.slice.Tensor expects either an integer or "
"None type as indices");
}
};

PyAnyTorchOptionalIntValue start =
parseAttr(getattr(sliceObject, "start"));
PyAnyTorchOptionalIntValue stop =
parseAttr(getattr(sliceObject, "stop"));
py::object step_attr = getattr(sliceObject, "step");
PyTorch_IntValue step =
py::isinstance<py::none>(step_attr) ? 1 : step_attr.cast<int>();

auto loc = getValueLocation(step);

return slice(self, dim, start, stop, step, &loc,
&DefaultingPyInsertionPoint::resolve());
});
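Together these three overloads cover the basic single-axis indexing forms: an int selects along dim 0, None unsqueezes a new leading dim, and a slice lowers to aten::slice.Tensor with None start/stop passed through and step defaulting to 1; tuple, list, and tensor indices remain unbound, hence "partial" in the commit title. A hedged sketch mirroring test_tensor_indexing below:

    import pi
    from pi.mlir.utils import mlir_mod_ctx

    with mlir_mod_ctx():
        x = pi.ones((4, 4))
        v0 = x[0]      # torch.aten.select.int on dim 0
        v1 = x[None]   # torch.aten.unsqueeze on dim 0
        v2 = x[1:3:2]  # torch.aten.slice.Tensor, start=1, stop=3, step=2
        v3 = x[:]      # torch.aten.slice.Tensor, none start/stop, step=1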

// @overload reshape(self, shape: Sequence[Union[_int, SymInt]]) -> Tensor
// aten::reshape : (Tensor, int...) -> (Tensor)
c.def("reshape",
[](const PyAnyTorchTensorValue &self,
const py::args &args) -> PyAnyTorchTensorValue {
auto shape = PyAnyTorchListOfTorchIntValue(args);
auto loc = getValueLocation(shape);
return reshape(self, shape, &loc,
&DefaultingPyInsertionPoint::resolve());
});
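reshape is bound variadically: the Python *args tuple is packed into a single !torch.list<int> value and forwarded to aten::reshape. A hedged sketch:

    import pi
    from pi.mlir.utils import mlir_mod_ctx

    with mlir_mod_ctx():
        x = pi.ones((4, 4))
        y = x.reshape(2, 8)  # (2, 8) packed into one int-list operand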

#include "TorchTensor.pybinds.cpp"
}

6 changes: 0 additions & 6 deletions cpp_ext/TorchTensor.pybinds.cpp
@@ -40,9 +40,6 @@ c.def("__ge__", [](const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValu
// aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)
c.def("__ge__", [](const PyAnyTorchTensorValue &self, const PyAnyTorchTensorValue &other, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return __ge__(self, other, loc.get(), ip.get()); }, "other"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

-// __getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor
-c.def("__getitem__", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: __getitem__ with signature __getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor"); });

// __gt__(self, other: Any) -> Tensor
// aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)
c.def("__gt__", [](const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &other, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return __gt__(self, other, loc.get(), ip.get()); }, "other"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());
@@ -286,9 +283,6 @@ c.def("_nested_tensor_strides", [](PyAnyTorchTensorValue& self, py::args args, p
// _nnz(self) -> _int
c.def("_nnz", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _nnz with signature _nnz(self) -> _int"); });

-// _sparse_mask_projection(self, mask: Tensor) -> Tensor
-c.def("_sparse_mask_projection", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _sparse_mask_projection with signature _sparse_mask_projection(self, mask: Tensor) -> Tensor"); });

// _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor
c.def("_to_dense", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _to_dense with signature _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor"); });

18 changes: 18 additions & 0 deletions cpp_ext/TorchValues.cpp
@@ -460,6 +460,24 @@ makeListIter<PyAnyTorchTensorValue, PyAnyTorchListOfTensorValue const>(
void PyAnyTorchListValue::bindDerived(ClassTy &c) {
c.def(py::init<py::list>(), "value"_a);
c.def(py::init<py::tuple>(), "value"_a);
c.def(
"__add__",
[](const PyAnyTorchListValue &self,
const PyAnyTorchListValue &other) -> PyAnyTorchListValue {
auto loc = getValueLocation(self);
return add(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);

c.def(
"__radd__",
[](const PyAnyTorchListValue &self,
const PyAnyTorchListValue &other) -> PyAnyTorchListValue {
        auto loc = getValueLocation(self);
        // __radd__ receives the left-hand operand as `other` (e.g. for
        // `[0] + x`), so concatenate other-first to preserve operand order.
        return add(other, self, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);

py::implicitly_convertible<py::list, PyAnyTorchListValue>();
py::implicitly_convertible<py::tuple, PyAnyTorchListValue>();
c.def("__len__", [](const PyAnyTorchListValue &self) { return self.length; });
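Because py::implicitly_convertible registers list-to-PyAnyTorchListValue and tuple-to-PyAnyTorchListValue conversions, __radd__ also covers a plain Python sequence on the left: the builtin's __add__ returns NotImplemented for the unfamiliar right-hand type, Python falls back to the value's __radd__, and the sequence is converted on entry. A hedged sketch mirroring test_ListConcatenation (AnyTorchListValue's import path is not shown in this diff):

    from pi.mlir.utils import mlir_mod_ctx
    # AnyTorchListValue imported as in tests/unit/test_mlir_values.py

    with mlir_mod_ctx():
        x = AnyTorchListValue([1, 2, 3])
        r1 = x + [4, 5]  # __add__ with implicit list -> value conversion
        r2 = (0,) + x    # tuple.__add__ -> NotImplemented -> x.__radd__
        # both concatenations lower to torch.aten.add.t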
2 changes: 1 addition & 1 deletion pi/__init__.py
@@ -17,7 +17,7 @@
LongTensor,
TensorPlaceholder,
dtype,
-empty,
+empty_placeholder,
layout,
memory_format,
ones,
2 changes: 1 addition & 1 deletion pi/mlir/utils.py
@@ -173,7 +173,7 @@ def wrapper(*args, **kwargs):
return wrapper


-empty = functools.partial(_np_wrapper, factory=np.empty)
+empty_placeholder = functools.partial(_np_wrapper, factory=np.empty)
ones = functools.partial(_np_wrapper, factory=star_args_wrapper(np.ones))
zeros = functools.partial(_np_wrapper, factory=star_args_wrapper(np.zeros))
rand = functools.partial(_np_wrapper, factory=np.random.rand)
4 changes: 2 additions & 2 deletions pi/nn/parameter.py
@@ -5,7 +5,7 @@
from typing import Union, List, Tuple


-from pi import Tensor, dtype, empty
+from pi import Tensor, dtype, empty_placeholder
import pi


@@ -27,7 +27,7 @@ def __new__(cls, *args, **keywords):
if inspect.isfunction(func) or isinstance(func, functools.partial):
args = args[1:]
else:
-func = empty
+func = empty_placeholder

if isinstance(args[0], (tuple, list)):
assert len(args) == 1, f"unknown len args {args}"
@@ -61,6 +61,7 @@ def get_clean_name(name):
"__truediv__(self, other: Any) -> Tensor",
"__rtruediv__(self, other: Any) -> Tensor",
"chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]",
"__getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor",
}

TORCH_OPS_IMPL_CPP = "TorchOps.impls.cpp"
37 changes: 37 additions & 0 deletions tests/unit/test_mlir_values.py
@@ -631,3 +631,40 @@ def test_ListIndexing(self):
"Torch_StringValue(%7 = torch.aten.__getitem__.t %6, %int0 : !torch.list<str>, !torch.int -> !torch.str)",
t,
)

def test_ListConcatenation(self):
with mlir_mod_ctx():
x = AnyTorchListValue([1, 2, 3])
y = AnyTorchListValue([4, 5])

r = x + y
check_correct(
"AnyTorchListValue(%DONT_CARE = torch.aten.add.t %DONT_CARE, %DONT_CARE : !torch.list<int>, !torch.list<int> -> !torch.list<int>)",
r,
)

py_list = [1, 2, 3]
r = x + py_list
check_correct(
"AnyTorchListValue(%DONT_CARE = torch.aten.add.t %DONT_CARE, %DONT_CARE : !torch.list<int>, !torch.list<int> -> !torch.list<int>)",
r,
)

r = py_list + x
check_correct(
"AnyTorchListValue(%DONT_CARE = torch.aten.add.t %DONT_CARE, %DONT_CARE : !torch.list<int>, !torch.list<int> -> !torch.list<int>)",
r,
)

py_tuple = (4, 5)
r = x + py_tuple
check_correct(
"AnyTorchListValue(%DONT_CARE = torch.aten.add.t %DONT_CARE, %DONT_CARE : !torch.list<int>, !torch.list<int> -> !torch.list<int>)",
r,
)

r = py_tuple + x
check_correct(
"AnyTorchListValue(%DONT_CARE = torch.aten.add.t %DONT_CARE, %DONT_CARE : !torch.list<int>, !torch.list<int> -> !torch.list<int>)",
r,
)
54 changes: 54 additions & 0 deletions tests/unit/test_tensor.py
@@ -2,6 +2,7 @@

import numpy as np

import pi
from pi import ops
from pi.mlir.utils import mlir_mod_ctx, int_op, non_value_tensor_op, bool_op, tensor_op
from util import check_correct
@@ -106,3 +107,56 @@ def test_optional_args(self):
"Tensor(%2 = torch.aten.gather %0, %int0, %0, %false_3 : !torch.tensor<[10,10],si64>, !torch.int, !torch.tensor<[10,10],si64>, !torch.bool -> !torch.tensor)",
r,
)

def test_tensor_div_overload(self):
with mlir_mod_ctx():
x = pi.ones(3)
y = pi.ones(3)
z = 2

d = x / z
check_correct(
"Tensor(%2 = torch.aten.div.Scalar %0, %int2 : !torch.tensor<[3],f64>, !torch.int -> !torch.tensor)",
d,
)

d = x / y
check_correct(
"Tensor(%3 = torch.aten.div.Tensor %0, %1 : !torch.tensor<[3],f64>, !torch.tensor<[3],f64> -> !torch.tensor)",
d,
)

# __rtruediv__ is computed via reciprocal(tensor) followed by a scalar multiplication
d = z / x
check_correct(
"Tensor(%5 = torch.aten.mul.Scalar %4, %int2 : !torch.tensor, !torch.int -> !torch.tensor)",
d,
)

def test_tensor_indexing(self):
with mlir_mod_ctx():
x = pi.ones((4, 4))

v = x[0]
check_correct(
"Tensor(%1 = torch.aten.select.int %0, %int0_0, %int0 : !torch.tensor<[4,4],f64>, !torch.int, !torch.int -> !torch.tensor)",
v,
)

v = x[None]
check_correct(
"Tensor(%3 = torch.aten.unsqueeze %0, %int0_1 : !torch.tensor<[4,4],f64>, !torch.int -> !torch.tensor)",
v,
)

v = x[1:3:2]
check_correct(
"Tensor(%4 = torch.aten.slice.Tensor %0, %int0_4, %int1_2, %int3, %int2_3 : !torch.tensor<[4,4],f64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor)",
v,
)

v = x[:]
check_correct(
"Tensor(%5 = torch.aten.slice.Tensor %0, %int0_7, %none, %none_5, %int1_6 : !torch.tensor<[4,4],f64>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.tensor)",
v,
)
