Skip to content

Commit

Permalink
Implement Array ufunc (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgsavage authored Jun 9, 2024
1 parent bdce199 commit 6cddc10
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pint-pandas Changelog

- Fix dequantify duplicate column failure #202
- Fix astype issue #196

- Support for `__array_ufunc__` and unary ops. #160

0.5 (2023-09-07)
----------------
Expand Down
59 changes: 59 additions & 0 deletions pint_pandas/pint_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import numbers
import re
import warnings
from importlib.metadata import version
Expand Down Expand Up @@ -218,6 +219,15 @@ def __repr__(self):
dtypeunmap = {v: k for k, v in dtypemap.items()}


def convert_np_inputs(inputs):
if isinstance(inputs, tuple):
return tuple(x.quantity if isinstance(x, PintArray) else x for x in inputs)
if isinstance(inputs, dict):
return {
item: (x.quantity if isinstance(x, PintArray) else x) for item, x in inputs
}


class PintArray(ExtensionArray, ExtensionScalarOpsMixin):
"""Implements a class to describe an array of physical quantities:
the product of an array of numerical values and a unit of measurement.
Expand All @@ -240,6 +250,7 @@ class PintArray(ExtensionArray, ExtensionScalarOpsMixin):
_data: ExtensionArray = cast(ExtensionArray, np.array([]))
context_name = None
context_units = None
_HANDLED_TYPES = (np.ndarray, numbers.Number, _Quantity)

def __init__(self, values, dtype=None, copy=False):
if dtype is None:
Expand Down Expand Up @@ -281,6 +292,54 @@ def __setstate__(self, dct):
self.__dict__.update(dct)
self._Q = self.dtype.ureg.Quantity

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out = kwargs.get("out", ())
for x in inputs + out:
# Only support operations with instances of _HANDLED_TYPES.
# Use ArrayLike instead of type(self) for isinstance to
# allow subclasses that don't override __array_ufunc__ to
# handle ArrayLike objects.
if not isinstance(x, self._HANDLED_TYPES + (PintArray,)):
return NotImplemented

# Defer to pint's implementation of the ufunc.
inputs = convert_np_inputs(inputs)
if out:
kwargs["out"] = convert_np_inputs(out)
print(inputs)
result = getattr(ufunc, method)(*inputs, **kwargs)
return self._convert_np_result(result)

def _convert_np_result(self, result):
if isinstance(result, _Quantity) and is_list_like(result.m):
return PintArray.from_1darray_quantity(result)
elif isinstance(result, _Quantity):
return result
elif type(result) is tuple:
# multiple return values
return tuple(type(self)(x) for x in result)
elif isinstance(result, np.ndarray) and all(
isinstance(item, _Quantity) for item in result
):
return PintArray._from_sequence(result)
elif result is None:
# no return value
return result
elif pd.api.types.is_bool_dtype(result):
return result
else:
# one return value
return type(self)(result)

def __pos__(self):
return 1 * self

def __neg__(self):
return -1 * self

def __abs__(self):
return self._Q(np.abs(self._data), self._dtype.units)

@property
def dtype(self):
# type: () -> ExtensionDtype
Expand Down
11 changes: 11 additions & 0 deletions pint_pandas/testsuite/test_pandas_extensiontests.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,17 @@ def test_setitem_2d_values(self, data):
assert (df.loc[1, :] == original[0]).all()


class TestUnaryOps(base.BaseUnaryOpsTests):
@pytest.mark.xfail(run=True, reason="invert not implemented")
def test_invert(self, data):
base.BaseUnaryOpsTests.test_invert(self, data)

@pytest.mark.xfail(run=True, reason="np.positive requires pint 0.21")
@pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
base.BaseUnaryOpsTests.test_unary_ufunc_dunder_equivalence(self, data, ufunc)


class TestAccumulate(base.BaseAccumulateTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
Expand Down

0 comments on commit 6cddc10

Please sign in to comment.