Skip to content

Commit

Permalink
Merge pull request #400 from neutrinoceros/nep18_last
Browse files Browse the repository at this point in the history
ENH: (NEP 18) implement and test np.linalg.svd
  • Loading branch information
ngoldbaum committed Mar 31, 2023
2 parents d17c0da + c9f8f6e commit 7c88ad6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
14 changes: 14 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
np.piecewise, # astropy.units doens't have a simple implementation either
np.packbits,
np.unpackbits,
np.ix_,
}

_HANDLED_FUNCTIONS = {}
Expand Down Expand Up @@ -119,6 +120,19 @@ def linalg_pinv(a, *args, **kwargs):
return np.linalg.pinv._implementation(a, *args, **kwargs).view(np.ndarray) / a.units


@implements(np.linalg.svd)
def linalg_svd(a, full_matrices=True, compute_uv=True, *args, **kwargs):
ret_units = a.units
retv = np.linalg.svd._implementation(
a.view(np.ndarray), full_matrices, compute_uv, *args, **kwargs
)
if compute_uv:
u, s, vh = retv
return (u, s * ret_units, vh)
else:
return retv * ret_units


def _sanitize_range(_range, units):
# helper function to histogram* functions
ndim = len(units)
Expand Down
24 changes: 14 additions & 10 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@
}


# this set represents all functions that need inspection, tests, or both
# it is always possible that some of its elements belong in NOOP_FUNCTIONS
TODO_FUNCTIONS = {
np.ix_,
np.linalg.svd,
}

DEPRECATED_FUNCTIONS = {
"alen", # deprecated in numpy 1.18, removed in 1.22
"asscalar", # deprecated in numpy 1.18, removed in 1.22
Expand All @@ -189,9 +182,7 @@
"cumproduct", # deprecated in numpy 1.25
}

NOT_HANDLED_FUNCTIONS = (
NOOP_FUNCTIONS | TODO_FUNCTIONS | UNSUPPORTED_FUNCTIONS | IGNORED_FUNCTIONS
)
NOT_HANDLED_FUNCTIONS = NOOP_FUNCTIONS | UNSUPPORTED_FUNCTIONS | IGNORED_FUNCTIONS

for func in DEPRECATED_FUNCTIONS:
if hasattr(np, func):
Expand Down Expand Up @@ -1166,6 +1157,19 @@ def test_eigvals(func):
assert w.units == cm


def test_linalg_svd():
a = (np.random.randn(9, 6) + 1j * np.random.randn(9, 6)) * cm
u, s, vh = np.linalg.svd(a)
assert type(u) is np.ndarray
assert type(vh) is np.ndarray
assert type(s) is unyt_array
assert s.units == cm

s = np.linalg.svd(a, compute_uv=False)
assert type(s) is unyt_array
assert s.units == cm


def test_savetxt(tmp_path):
a = [1, 2, 3] * cm
with pytest.raises(
Expand Down

0 comments on commit 7c88ad6

Please sign in to comment.