Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix raising a unyt array to an array power in sensible cases #524

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,8 +1764,8 @@ def __pow__(self, p, mod=None, /):
Power function
"""
# see https://github.com/yt-project/unyt/issues/203
if p == 0.0:
ret = self.ua
if np.isscalar(p) and p == 0.0:
ret = self.unit_array
ret.units = Unit("dimensionless")
return ret
else:
Expand Down Expand Up @@ -1854,17 +1854,32 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
u1 = Unit(registry=getattr(u0, "registry", None))
elif ufunc is power:
u1 = inp1
if inp0.shape != () and inp1.shape != ():
raise UnitOperationError(ufunc, u0, u1)
if isinstance(u1, unyt_array):
if u1.units.is_dimensionless:
pass
else:
if inp0.shape == () or inp1.shape == ():
if isinstance(u1, unyt_array) and not u1.units.is_dimensionless:
raise UnitOperationError(ufunc, u0, u1.units)
if u1.shape == ():
u1 = float(u1)
if u1.shape == ():
u1 = float(u1)
else:
u1 = 1.0
elif inp0.shape == inp1.shape:
if isinstance(u1, unyt_array) and not u1.units.is_dimensionless:
raise UnitOperationError(ufunc, u0, getattr(u1, "units", None))

if (
(isinstance(u0, Unit) and not u0.is_dimensionless)
or isinstance(u0, unyt_array)
and not u0.units.is_dimensionless
):
# u0 has units
if np.ptp(u1) != 0:
raise UnitOperationError(
ufunc, u0, getattr(u1, "units", None)
)

first_element_slice = (0,) * u1.ndim
u1 = float(u1[first_element_slice])
else:
u1 = 1.0
raise UnitOperationError(ufunc, u0, u1)
unit_operator = self._ufunc_registry[ufunc]

if (
Expand Down
38 changes: 38 additions & 0 deletions unyt/tests/test_unyt_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,44 @@ def test_power():
assert_array_equal_units(cm_arr**0, unyt_quantity(1.0, "dimensionless"))


@pytest.mark.parametrize(
"p",
[
pytest.param(unyt_array(0), id="scalar power 0"),
pytest.param(unyt_array(1), id="scalar power 1"),
pytest.param(unyt_array([0, 0, 0, 0]), id="uniform power 0"),
pytest.param(unyt_array([1, 1, 1, 1]), id="uniform power 1"),
pytest.param(unyt_array([0, 1, 2, 3]), id="non-uniform power"),
],
)
def test_dimensionless_array_power(p):
# see https://github.com/yt-project/unyt/issues/522
a = unyt_array([1, 2, 3, 4])
expected = a.value**p.value
assert_array_equal(expected, (a**p).value)


@pytest.mark.parametrize(
"a, p",
[
pytest.param(
unyt_array([0, 1, 2, 3], "cm"),
unyt_array([0, 1, 2, 3]),
id="non-uniform power with arr non dimensionless",
),
pytest.param(
unyt_array([0, 1, 2, 3]),
unyt_array([0, 0, 0, 0], "cm"),
id="non dimensionless power",
),
],
)
def test_invalid_array_power(a, p):
# see https://github.com/yt-project/unyt/issues/522
with pytest.raises(UnitOperationError):
a**p


def test_comparisons():
"""
Test numpy ufunc comparison operators for unit consistency.
Expand Down
Loading