Skip to content

TST: run parameterized tests on more libraries #227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,13 +649,9 @@ def device(x: Array, /) -> Device:
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the jax array to determine type
try:
import numpy as np
if isinstance(x._meta, np.ndarray):
# Must be on CPU since backed by numpy
return "cpu"
except ImportError:
pass
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
elif is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
Expand Down
5 changes: 5 additions & 0 deletions docs/supported-array-libraries.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,8 @@ The minimum supported Dask version is 2023.12.0.
## [Sparse](https://sparse.pydata.org/en/stable/)

Similar to JAX, `sparse` Array API support is contained directly in `sparse`.

(array-api-strict-support)=
## [array-api-strict](https://data-apis.org/array-api-strict/)

array-api-strict exists only to test support for the Array API, so it does not need any wrappers.
10 changes: 2 additions & 8 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from importlib import import_module
import sys

import pytest

wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["jax.numpy"]
all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"]

# `sparse` added array API support as of Python 3.10.
if sys.version_info >= (3, 10):
all_libraries.append('sparse')

def import_(library, wrapper=False):
if library == 'cupy':
Expand All @@ -20,9 +16,7 @@ def import_(library, wrapper=False):
jax_numpy = import_module("jax.numpy")
if not hasattr(jax_numpy, "__array_api_version__"):
library = 'jax.experimental.array_api'
elif library.startswith('sparse'):
library = 'sparse'
else:
elif library in wrapped_libraries:
library = 'array_api_compat.' + library

return import_module(library)
5 changes: 4 additions & 1 deletion tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
def test_all(library):
import_(library, wrapper=True)
if library == "common":
import array_api_compat.common # noqa: F401
else:
import_(library, wrapper=True)

for mod_name in sys.modules:
if not mod_name.startswith('array_api_compat.' + library):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

@pytest.mark.parametrize("use_compat", [True, False, None])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
@pytest.mark.parametrize("library", all_libraries)
def test_array_namespace(library, api_version, use_compat):
xp = import_(library)

array = xp.asarray([1.0, 2.0, 3.0])
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
if use_compat and library not in wrapped_libraries:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
Expand Down
24 changes: 19 additions & 5 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
is_array_api_strict_namespace,
)

from array_api_compat import (
Expand All @@ -33,6 +34,7 @@
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
'array_api_strict': 'is_array_api_strict_namespace',
}


Expand Down Expand Up @@ -74,7 +76,12 @@ def test_xp_is_array_generics(library):
is_func = globals()[func]
if is_func(x0):
matches.append(library2)
assert matches in ([library], ["numpy"])

if library == "array_api_strict":
# There is no is_array_api_strict_array() function
assert matches == []
else:
assert matches in ([library], ["numpy"])


@pytest.mark.parametrize("library", all_libraries)
Expand Down Expand Up @@ -213,26 +220,33 @@ def test_to_device_host(library):
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
if source_library == "dask.array" and target_library == "torch":
def _xfail(reason: str) -> None:
# Allow rest of test to execute instead of immediately xfailing
# xref https://github.com/pandas-dev/pandas/issues/38902
request.node.add_marker(pytest.mark.xfail(reason=reason))

if source_library == "dask.array" and target_library == "torch":
# TODO: remove xfail once
# https://github.com/dask/dask/issues/8260 is resolved
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
if source_library == "cupy" and target_library != "cupy":
_xfail(reason="Bug in dask raising error on conversion")
elif source_library == "jax.numpy" and target_library == "torch":
_xfail(reason="casts int to float")
elif source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")

src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[is_array_functions[target_library]]

a = src_lib.asarray([1, 2, 3])
a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
b = tgt_lib.asarray(a)

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
assert b.dtype == tgt_lib.int32



@pytest.mark.parametrize("library", wrapped_libraries)
Expand Down
Loading