Skip to content

Commit

Permalink
Merge pull request #399 from neutrinoceros/nep18_unsupported
Browse files Browse the repository at this point in the history
ENH: (NEP 18) follow NEP 18 more closely for unsupported functions
  • Loading branch information
ngoldbaum authored Mar 30, 2023
2 parents 1b84388 + c426635 commit d17c0da
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 26 deletions.
26 changes: 26 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,32 @@
UnytError,
)

# Functions for which passing units doesn't make sense
# bail out with NotImplemented (escalated to TypeError by numpy)
_UNSUPPORTED_FUNCTIONS = {
# Polynomials
np.poly,
np.polyadd,
np.polyder,
np.polydiv,
np.polyfit,
np.polyint,
np.polymul,
np.polysub,
np.polyval,
np.roots,
np.vander,
# datetime64 is not a sensible dtype for unyt_array
np.datetime_as_string,
np.busday_count,
np.busday_offset,
np.is_busday,
# not clear how to approach
np.piecewise, # astropy.units doens't have a simple implementation either
np.packbits,
np.unpackbits,
}

_HANDLED_FUNCTIONS = {}


Expand Down
8 changes: 7 additions & 1 deletion unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1982,7 +1982,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
def __array_function__(self, func, types, args, kwargs):
# Follow NEP 18 guidelines
# https://numpy.org/neps/nep-0018-array-function-protocol.html
from unyt._array_functions import _HANDLED_FUNCTIONS
from unyt._array_functions import _HANDLED_FUNCTIONS, _UNSUPPORTED_FUNCTIONS

if func in _UNSUPPORTED_FUNCTIONS:
# following NEP 18, return NotImplemented as a sentinel value
# which will lead to raising a TypeError, while
# leaving other arguments a chance to take the lead
return NotImplemented

if func not in _HANDLED_FUNCTIONS:
# default to numpy's private implementation
Expand Down
35 changes: 10 additions & 25 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from packaging.version import Version

from unyt import A, K, cm, degC, degF, delta_degC, g, km, rad, s
from unyt._array_functions import _HANDLED_FUNCTIONS as HANDLED_FUNCTIONS
from unyt._array_functions import (
_HANDLED_FUNCTIONS as HANDLED_FUNCTIONS,
_UNSUPPORTED_FUNCTIONS as UNSUPPORTED_FUNCTIONS,
)
from unyt.array import unyt_array, unyt_quantity
from unyt.exceptions import (
InvalidUnitOperation,
Expand Down Expand Up @@ -150,36 +153,16 @@
np.take_along_axis, # works out of the box (tested)
}

# Functions that are wrappable but don't really make sense with units
# Functions for which behaviour is intentionally left to default
IGNORED_FUNCTIONS = {
# Polynomials
np.poly,
np.polyadd,
np.polyder,
np.polydiv,
np.polyfit,
np.polyint,
np.polymul,
np.polysub,
np.polyval,
np.roots,
np.vander,
np.i0,
# IO functions (no way to add units)
np.save,
np.savez,
np.savez_compressed,
# datetime64 is not a sensible dtype for unyt_array
np.datetime_as_string,
np.busday_count,
np.busday_offset,
np.is_busday,
# not clear how to approach
np.piecewise, # astropy.units doens't have a simple implementation either
np.packbits,
np.unpackbits,
np.i0,
}


# this set represents all functions that need inspection, tests, or both
# it is always possible that some of its elements belong in NOOP_FUNCTIONS
TODO_FUNCTIONS = {
Expand All @@ -206,7 +189,9 @@
"cumproduct", # deprecated in numpy 1.25
}

NOT_HANDLED_FUNCTIONS = NOOP_FUNCTIONS | TODO_FUNCTIONS | IGNORED_FUNCTIONS
NOT_HANDLED_FUNCTIONS = (
NOOP_FUNCTIONS | TODO_FUNCTIONS | UNSUPPORTED_FUNCTIONS | IGNORED_FUNCTIONS
)

for func in DEPRECATED_FUNCTIONS:
if hasattr(np, func):
Expand Down

0 comments on commit d17c0da

Please sign in to comment.