Skip to content

Commit

Permalink
Merge pull request #101 from asmeurer/more-linalg2
Browse files Browse the repository at this point in the history
More improvements to test_linalg
  • Loading branch information
honno authored Feb 26, 2024
2 parents 4f83bb3 + 0ddb0cd commit a1d7701
Show file tree
Hide file tree
Showing 9 changed files with 555 additions and 221 deletions.
22 changes: 17 additions & 5 deletions array_api_tests/array_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
# These are exported here so that they can be included in the special cases
# tests from this file.
from ._array_module import logical_not, subtract, floor, ceil, where
from . import _array_module as xp
from . import dtype_helpers as dh


__all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN',
Expand Down Expand Up @@ -164,19 +164,21 @@ def notequal(x, y):

return not_equal(x, y)

def assert_exactly_equal(x, y):
def assert_exactly_equal(x, y, msg_extra=None):
"""
Test that the arrays x and y are exactly equal.
If x and y do not have the same shape and dtype, they are not considered
equal.
"""
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
extra = '' if not msg_extra else f' ({msg_extra})'

assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"

assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"

assert all(exactly_equal(x, y)), "The input arrays have different values"
assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}"

def assert_finite(x):
"""
Expand Down Expand Up @@ -306,3 +308,13 @@ def same_sign(x, y):
def assert_same_sign(x, y):
assert all(same_sign(x, y)), "The input arrays do not have the same sign"

def _matrix_transpose(x):
if not isinstance(xp.matrix_transpose, xp._UndefinedStub):
return xp.matrix_transpose(x)
if hasattr(x, 'mT'):
return x.mT
if not isinstance(xp.permute_dims, xp._UndefinedStub):
perm = list(range(x.ndim))
perm[-1], perm[-2] = perm[-2], perm[-1]
return xp.permute_dims(x, axes=tuple(perm))
raise NotImplementedError("No way to compute matrix transpose")
51 changes: 51 additions & 0 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,57 @@ class MinMax(NamedTuple):
{"complex64": xp.float32, "complex128": xp.float64}
)

def as_real_dtype(dtype):
"""
Return the corresponding real dtype for a given floating-point dtype.
"""
if dtype in real_float_dtypes:
return dtype
elif dtype_to_name[dtype] in complex_names:
return dtype_components[dtype]
else:
raise ValueError("as_real_dtype requires a floating-point dtype")

def accumulation_result_dtype(x_dtype, dtype_kwarg):
"""
Result dtype logic for sum(), prod(), and trace()
Note: may return None if a default uint cannot exist (e.g., for pytorch
which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106
"""
if dtype_kwarg is None:
if is_int_dtype(x_dtype):
if x_dtype in uint_dtypes:
default_dtype = default_uint
else:
default_dtype = default_int
if default_dtype is None:
_dtype = None
else:
m, M = dtype_ranges[x_dtype]
d_m, d_M = dtype_ranges[default_dtype]
if m < d_m or M > d_M:
_dtype = x_dtype
else:
_dtype = default_dtype
elif is_float_dtype(x_dtype, include_complex=False):
if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
_dtype = x_dtype
else:
_dtype = default_float
elif api_version > "2021.12":
# Complex dtype
if dtype_nbits[x_dtype] > dtype_nbits[default_complex]:
_dtype = x_dtype
else:
_dtype = default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype_kwarg

return _dtype

if not hasattr(xp, "asarray"):
default_int = xp.int32
Expand Down
30 changes: 21 additions & 9 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
sampled_from, shared, builds)

from . import _array_module as xp, api_version
from . import array_helpers as ah
from . import dtype_helpers as dh
from . import shape_helpers as sh
from . import xps
Expand Down Expand Up @@ -211,6 +212,7 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)

# Use this to avoid memory errors with NumPy.
# See https://github.com/numpy/numpy/issues/15753
# Note, the hypothesis default for max_dims is min_dims + 2 (i.e., 0 + 2)
def shapes(**kw):
kw.setdefault('min_dims', 0)
kw.setdefault('min_side', 0)
Expand Down Expand Up @@ -280,25 +282,29 @@ def mutually_broadcastable_shapes(

# Note: This should become hermitian_matrices when complex dtypes are added
@composite
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.):
shape = draw(square_matrix_shapes)
dtype = draw(dtypes)
if not isinstance(finite, bool):
finite = draw(finite)
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
a = draw(arrays(dtype=dtype, shape=shape, elements=elements))
upper = xp.triu(a)
lower = xp.triu(a, k=1).mT
return upper + lower
at = ah._matrix_transpose(a)
H = (a + at)*0.5
if finite:
assume(not xp.any(xp.isinf(H)))
assume(xp.all((H == 0.) | ((1/bound <= xp.abs(H)) & (xp.abs(H) <= bound))))
return H

@composite
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
# For now just generate stacks of identity matrices
# TODO: Generate arbitrary positive definite matrices, for instance, by
# using something like
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
n = draw(integers(0))
shape = draw(shapes()) + (n, n)
base_shape = draw(shapes())
n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
shape = base_shape + (n, n)
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
dtype = draw(dtypes)
return broadcast_to(eye(n, dtype=dtype), shape)
Expand All @@ -308,12 +314,18 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
# For now, just generate stacks of diagonal matrices.
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
stack_shape = draw(stack_shapes)
d = draw(arrays(dtypes, shape=(*stack_shape, 1, n),
elements=dict(allow_nan=False, allow_infinity=False)))
dtype = draw(dtypes)
elements = one_of(
from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False),
from_dtype(dtype, max_value=-0.5, allow_nan=False, allow_infinity=False),
)
d = draw(arrays(dtype, shape=(*stack_shape, 1, n), elements=elements))

# Functions that require invertible matrices may do anything when it is
# singular, including raising an exception, so we make sure the diagonals
# are sufficiently nonzero to avoid any numerical issues.
assume(xp.all(xp.abs(d) > 0.5))
assert xp.all(xp.abs(d) >= 0.5)

diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1))
return xp.where(diag_mask, d, xp.zeros_like(d))

Expand Down
16 changes: 16 additions & 0 deletions array_api_tests/meta/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from hypothesis import given

from ..hypothesis_helpers import symmetric_matrices
from .. import array_helpers as ah
from .. import _array_module as xp

@pytest.mark.xp_extension('linalg')
@given(x=symmetric_matrices(finite=True))
def test_symmetric_matrices(x):
upper = xp.triu(x)
lower = xp.tril(x)
lowerT = ah._matrix_transpose(lower)

ah.assert_exactly_equal(upper, lowerT)
Loading

0 comments on commit a1d7701

Please sign in to comment.