diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 6b321aca..06232e47 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -13,6 +13,32 @@ UnytError, ) +# Functions for which passing units doesn't make sense +# bail out with NotImplemented (escalated to TypeError by numpy) +_UNSUPPORTED_FUNCTIONS = { + # Polynomials + np.poly, + np.polyadd, + np.polyder, + np.polydiv, + np.polyfit, + np.polyint, + np.polymul, + np.polysub, + np.polyval, + np.roots, + np.vander, + # datetime64 is not a sensible dtype for unyt_array + np.datetime_as_string, + np.busday_count, + np.busday_offset, + np.is_busday, + # not clear how to approach + np.piecewise, # astropy.units doens't have a simple implementation either + np.packbits, + np.unpackbits, +} + _HANDLED_FUNCTIONS = {} diff --git a/unyt/array.py b/unyt/array.py index 98645689..71c5b69c 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -1982,7 +1982,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): def __array_function__(self, func, types, args, kwargs): # Follow NEP 18 guidelines # https://numpy.org/neps/nep-0018-array-function-protocol.html - from unyt._array_functions import _HANDLED_FUNCTIONS + from unyt._array_functions import _HANDLED_FUNCTIONS, _UNSUPPORTED_FUNCTIONS + + if func in _UNSUPPORTED_FUNCTIONS: + # following NEP 18, return NotImplemented as a sentinel value + # which will lead to raising a TypeError, while + # leaving other arguments a chance to take the lead + return NotImplemented if func not in _HANDLED_FUNCTIONS: # default to numpy's private implementation diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index d5f5f3b9..b75e6a0b 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -7,7 +7,10 @@ from packaging.version import Version from unyt import A, K, cm, degC, degF, delta_degC, g, km, rad, s -from unyt._array_functions import _HANDLED_FUNCTIONS as HANDLED_FUNCTIONS +from unyt._array_functions import ( + _HANDLED_FUNCTIONS as HANDLED_FUNCTIONS, + _UNSUPPORTED_FUNCTIONS as UNSUPPORTED_FUNCTIONS, +) from unyt.array import unyt_array, unyt_quantity from unyt.exceptions import ( InvalidUnitOperation, @@ -150,36 +153,16 @@ np.take_along_axis, # works out of the box (tested) } -# Functions that are wrappable but don't really make sense with units +# Functions for which behaviour is intentionally left to default IGNORED_FUNCTIONS = { - # Polynomials - np.poly, - np.polyadd, - np.polyder, - np.polydiv, - np.polyfit, - np.polyint, - np.polymul, - np.polysub, - np.polyval, - np.roots, - np.vander, + np.i0, # IO functions (no way to add units) np.save, np.savez, np.savez_compressed, - # datetime64 is not a sensible dtype for unyt_array - np.datetime_as_string, - np.busday_count, - np.busday_offset, - np.is_busday, - # not clear how to approach - np.piecewise, # astropy.units doens't have a simple implementation either - np.packbits, - np.unpackbits, - np.i0, } + # this set represents all functions that need inspection, tests, or both # it is always possible that some of its elements belong in NOOP_FUNCTIONS TODO_FUNCTIONS = { @@ -206,7 +189,9 @@ "cumproduct", # deprecated in numpy 1.25 } -NOT_HANDLED_FUNCTIONS = NOOP_FUNCTIONS | TODO_FUNCTIONS | IGNORED_FUNCTIONS +NOT_HANDLED_FUNCTIONS = ( + NOOP_FUNCTIONS | TODO_FUNCTIONS | UNSUPPORTED_FUNCTIONS | IGNORED_FUNCTIONS +) for func in DEPRECATED_FUNCTIONS: if hasattr(np, func):