Skip to content

Commit 4e7603a

Browse files
committed
Run more tests on array-api-strict and sparse
1 parent beac55b commit 4e7603a

File tree

6 files changed

+35
-23
lines changed

6 files changed

+35
-23
lines changed

array_api_compat/common/_helpers.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -631,13 +631,9 @@ def device(x: Array, /) -> Device:
631631
return "cpu"
632632
elif is_dask_array(x):
633633
# Peek at the metadata of the jax array to determine type
634-
try:
635-
import numpy as np
636-
if isinstance(x._meta, np.ndarray):
637-
# Must be on CPU since backed by numpy
638-
return "cpu"
639-
except ImportError:
640-
pass
634+
if is_numpy_array(x._meta):
635+
# Must be on CPU since backed by numpy
636+
return "cpu"
641637
return _DASK_DEVICE
642638
elif is_jax_array(x):
643639
# JAX has .device() as a method, but it is being deprecated so that it

docs/supported-array-libraries.md

+5
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,8 @@ The minimum supported Dask version is 2023.12.0.
137137
## [Sparse](https://sparse.pydata.org/en/stable/)
138138

139139
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
140+
141+
(array-api-strict-support)=
142+
## [array-api-strict](https://data-apis.org/array-api-strict/)
143+
144+
array-api-strict exists only to test support for the Array API, so it does not need any wrappers.

tests/_helpers.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
from importlib import import_module
2-
import sys
32

43
import pytest
54

65
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
7-
all_libraries = wrapped_libraries + ["jax.numpy"]
6+
all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"]
87

9-
# `sparse` added array API support as of Python 3.10.
10-
if sys.version_info >= (3, 10):
11-
all_libraries.append('sparse')
128

139
def import_(library, wrapper=False):
1410
if library == 'cupy':
@@ -20,9 +16,7 @@ def import_(library, wrapper=False):
2016
jax_numpy = import_module("jax.numpy")
2117
if not hasattr(jax_numpy, "__array_api_version__"):
2218
library = 'jax.experimental.array_api'
23-
elif library.startswith('sparse'):
24-
library = 'sparse'
25-
else:
19+
elif library in wrapped_libraries:
2620
library = 'array_api_compat.' + library
2721

2822
return import_module(library)

tests/test_all.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818

1919
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
2020
def test_all(library):
21-
import_(library, wrapper=True)
21+
if library == "common":
22+
import array_api_compat.common # noqa: F401
23+
else:
24+
import_(library, wrapper=True)
2225

2326
for mod_name in sys.modules:
2427
if not mod_name.startswith('array_api_compat.' + library):

tests/test_array_namespace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
@pytest.mark.parametrize("use_compat", [True, False, None])
1616
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
17-
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
17+
@pytest.mark.parametrize("library", all_libraries)
1818
def test_array_namespace(library, api_version, use_compat):
1919
xp = import_(library)
2020

2121
array = xp.asarray([1.0, 2.0, 3.0])
22-
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
22+
if use_compat and library not in wrapped_libraries:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
2525
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)

tests/test_common.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
is_dask_array, is_jax_array, is_pydata_sparse_array,
44
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
6+
is_array_api_strict_namespace,
67
)
78

89
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
@@ -30,6 +31,7 @@
3031
'dask.array': 'is_dask_namespace',
3132
'jax.numpy': 'is_jax_namespace',
3233
'sparse': 'is_pydata_sparse_namespace',
34+
'array_api_strict': 'is_array_api_strict_namespace',
3335
}
3436

3537

@@ -71,7 +73,12 @@ def test_xp_is_array_generics(library):
7173
is_func = globals()[func]
7274
if is_func(x0):
7375
matches.append(library2)
74-
assert matches in ([library], ["numpy"])
76+
77+
if library == "array_api_strict":
78+
# There is no is_array_api_strict_array() function
79+
assert matches == []
80+
else:
81+
assert matches in ([library], ["numpy"])
7582

7683

7784
@pytest.mark.parametrize("library", all_libraries)
@@ -128,26 +135,33 @@ def test_to_device_host(library):
128135
@pytest.mark.parametrize("target_library", is_array_functions.keys())
129136
@pytest.mark.parametrize("source_library", is_array_functions.keys())
130137
def test_asarray_cross_library(source_library, target_library, request):
131-
if source_library == "dask.array" and target_library == "torch":
138+
def _xfail(reason: str) -> None:
132139
# Allow rest of test to execute instead of immediately xfailing
133140
# xref https://github.com/pandas-dev/pandas/issues/38902
141+
request.node.add_marker(pytest.mark.xfail(reason=reason))
134142

143+
if source_library == "dask.array" and target_library == "torch":
135144
# TODO: remove xfail once
136145
# https://github.com/dask/dask/issues/8260 is resolved
137-
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
138-
if source_library == "cupy" and target_library != "cupy":
146+
_xfail(reason="Bug in dask raising error on conversion")
147+
elif source_library == "jax.numpy" and target_library == "torch":
148+
_xfail(reason="casts int to float")
149+
elif source_library == "cupy" and target_library != "cupy":
139150
# cupy explicitly disallows implicit conversions to CPU
140151
pytest.skip(reason="cupy does not support implicit conversion to CPU")
141152
elif source_library == "sparse" and target_library != "sparse":
142153
pytest.skip(reason="`sparse` does not allow implicit densification")
154+
143155
src_lib = import_(source_library, wrapper=True)
144156
tgt_lib = import_(target_library, wrapper=True)
145157
is_tgt_type = globals()[is_array_functions[target_library]]
146158

147-
a = src_lib.asarray([1, 2, 3])
159+
a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
148160
b = tgt_lib.asarray(a)
149161

150162
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
163+
assert b.dtype == tgt_lib.int32
164+
151165

152166
@pytest.mark.parametrize("library", wrapped_libraries)
153167
def test_asarray_copy(library):

0 commit comments

Comments
 (0)