Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Array ufunc #160

Merged
merged 12 commits into from
Jun 9, 2024
2 changes: 1 addition & 1 deletion CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pint-pandas Changelog

- Fix dequantify duplicate column failure #202
- Fix astype issue #196

- Support for `__array_ufunc__` and unary ops. #160

0.5 (2023-09-07)
----------------
Expand Down
59 changes: 59 additions & 0 deletions pint_pandas/pint_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import numbers
import re
import warnings
from importlib.metadata import version
Expand Down Expand Up @@ -218,6 +219,15 @@ def __repr__(self):
dtypeunmap = {v: k for k, v in dtypemap.items()}


def convert_np_inputs(inputs):
if isinstance(inputs, tuple):
return tuple(x.quantity if isinstance(x, PintArray) else x for x in inputs)
if isinstance(inputs, dict):
return {
item: (x.quantity if isinstance(x, PintArray) else x) for item, x in inputs
}


class PintArray(ExtensionArray, ExtensionScalarOpsMixin):
"""Implements a class to describe an array of physical quantities:
the product of an array of numerical values and a unit of measurement.
Expand All @@ -240,6 +250,7 @@ class PintArray(ExtensionArray, ExtensionScalarOpsMixin):
_data: ExtensionArray = cast(ExtensionArray, np.array([]))
context_name = None
context_units = None
_HANDLED_TYPES = (np.ndarray, numbers.Number, _Quantity)

def __init__(self, values, dtype=None, copy=False):
if dtype is None:
Expand Down Expand Up @@ -281,6 +292,54 @@ def __setstate__(self, dct):
self.__dict__.update(dct)
self._Q = self.dtype.ureg.Quantity

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out = kwargs.get("out", ())
for x in inputs + out:
# Only support operations with instances of _HANDLED_TYPES.
# Use ArrayLike instead of type(self) for isinstance to
# allow subclasses that don't override __array_ufunc__ to
# handle ArrayLike objects.
if not isinstance(x, self._HANDLED_TYPES + (PintArray,)):
return NotImplemented

# Defer to pint's implementation of the ufunc.
inputs = convert_np_inputs(inputs)
if out:
kwargs["out"] = convert_np_inputs(out)
print(inputs)
result = getattr(ufunc, method)(*inputs, **kwargs)
return self._convert_np_result(result)

def _convert_np_result(self, result):
if isinstance(result, _Quantity) and is_list_like(result.m):
return PintArray.from_1darray_quantity(result)
elif isinstance(result, _Quantity):
return result
elif type(result) is tuple:
# multiple return values
return tuple(type(self)(x) for x in result)
elif isinstance(result, np.ndarray) and all(
isinstance(item, _Quantity) for item in result
):
return PintArray._from_sequence(result)
elif result is None:
# no return value
return result
elif pd.api.types.is_bool_dtype(result):
return result
else:
# one return value
return type(self)(result)

def __pos__(self):
return 1 * self

def __neg__(self):
return -1 * self

def __abs__(self):
return self._Q(np.abs(self._data), self._dtype.units)

@property
def dtype(self):
# type: () -> ExtensionDtype
Expand Down
11 changes: 11 additions & 0 deletions pint_pandas/testsuite/test_pandas_extensiontests.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,17 @@ def test_setitem_2d_values(self, data):
assert (df.loc[1, :] == original[0]).all()


class TestUnaryOps(base.BaseUnaryOpsTests):
@pytest.mark.xfail(run=True, reason="invert not implemented")
def test_invert(self, data):
base.BaseUnaryOpsTests.test_invert(self, data)

@pytest.mark.xfail(run=True, reason="np.positive requires pint 0.21")
@pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
base.BaseUnaryOpsTests.test_unary_ufunc_dunder_equivalence(self, data, ufunc)


class TestAccumulate(base.BaseAccumulateTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
Expand Down
Loading