From 1fedbac6d75f66070ecdb26b9b53cf8ef3a26f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Tue, 8 Oct 2024 15:43:40 +0200 Subject: [PATCH] BUG: fix raising a unyt array to an array power in sensible cases --- unyt/array.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/unyt/array.py b/unyt/array.py index 2141a2c3..5a418739 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -1764,8 +1764,8 @@ def __pow__(self, p, mod=None, /): Power function """ # see https://github.com/yt-project/unyt/issues/203 - if p == 0.0: - ret = self.ua + if np.isscalar(p) and p == 0.0: + ret = self.unit_array ret.units = Unit("dimensionless") return ret else: @@ -1854,17 +1854,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): u1 = Unit(registry=getattr(u0, "registry", None)) elif ufunc is power: u1 = inp1 - if inp0.shape != () and inp1.shape != (): - raise UnitOperationError(ufunc, u0, u1) - if isinstance(u1, unyt_array): - if u1.units.is_dimensionless: - pass - else: + if inp0.shape == () or inp1.shape == (): + if isinstance(u1, unyt_array) and not u1.units.is_dimensionless: raise UnitOperationError(ufunc, u0, u1.units) - if u1.shape == (): - u1 = float(u1) + if u1.shape == (): + u1 = float(u1) + else: + u1 = 1.0 + elif inp0.shape == inp1.shape: + if ( + isinstance(u1, unyt_array) and not u1.units.is_dimensionless + ) or np.ptp(u1) != 0: + raise UnitOperationError(ufunc, u0, getattr(u1, "units", None)) + first_element_slice = (0,) * u1.ndim + u1 = float(u1[first_element_slice]) else: - u1 = 1.0 + raise UnitOperationError(ufunc, u0, u1) unit_operator = self._ufunc_registry[ufunc] if (