|
3 | 3 | is_dask_array, is_jax_array, is_pydata_sparse_array,
|
4 | 4 | is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
|
5 | 5 | is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
|
| 6 | + is_array_api_strict_namespace, |
6 | 7 | )
|
7 | 8 |
|
8 | 9 | from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
|
|
30 | 31 | 'dask.array': 'is_dask_namespace',
|
31 | 32 | 'jax.numpy': 'is_jax_namespace',
|
32 | 33 | 'sparse': 'is_pydata_sparse_namespace',
|
| 34 | + 'array_api_strict': 'is_array_api_strict_namespace', |
33 | 35 | }
|
34 | 36 |
|
35 | 37 |
|
@@ -71,7 +73,12 @@ def test_xp_is_array_generics(library):
|
71 | 73 | is_func = globals()[func]
|
72 | 74 | if is_func(x0):
|
73 | 75 | matches.append(library2)
|
74 |
| - assert matches in ([library], ["numpy"]) |
| 76 | + |
| 77 | + if library == "array_api_strict": |
| 78 | + # There is no is_array_api_strict_array() function |
| 79 | + assert matches == [] |
| 80 | + else: |
| 81 | + assert matches in ([library], ["numpy"]) |
75 | 82 |
|
76 | 83 |
|
77 | 84 | @pytest.mark.parametrize("library", all_libraries)
|
@@ -128,26 +135,33 @@ def test_to_device_host(library):
|
128 | 135 | @pytest.mark.parametrize("target_library", is_array_functions.keys())
|
129 | 136 | @pytest.mark.parametrize("source_library", is_array_functions.keys())
|
130 | 137 | def test_asarray_cross_library(source_library, target_library, request):
|
131 |
| - if source_library == "dask.array" and target_library == "torch": |
| 138 | + def _xfail(reason: str) -> None: |
132 | 139 | # Allow rest of test to execute instead of immediately xfailing
|
133 | 140 | # xref https://github.com/pandas-dev/pandas/issues/38902
|
| 141 | + request.node.add_marker(pytest.mark.xfail(reason=reason)) |
134 | 142 |
|
| 143 | + if source_library == "dask.array" and target_library == "torch": |
135 | 144 | # TODO: remove xfail once
|
136 | 145 | # https://github.com/dask/dask/issues/8260 is resolved
|
137 |
| - request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion")) |
138 |
| - if source_library == "cupy" and target_library != "cupy": |
| 146 | + _xfail(reason="Bug in dask raising error on conversion") |
| 147 | + elif source_library == "jax.numpy" and target_library == "torch": |
| 148 | + _xfail(reason="casts int to float") |
| 149 | + elif source_library == "cupy" and target_library != "cupy": |
139 | 150 | # cupy explicitly disallows implicit conversions to CPU
|
140 | 151 | pytest.skip(reason="cupy does not support implicit conversion to CPU")
|
141 | 152 | elif source_library == "sparse" and target_library != "sparse":
|
142 | 153 | pytest.skip(reason="`sparse` does not allow implicit densification")
|
| 154 | + |
143 | 155 | src_lib = import_(source_library, wrapper=True)
|
144 | 156 | tgt_lib = import_(target_library, wrapper=True)
|
145 | 157 | is_tgt_type = globals()[is_array_functions[target_library]]
|
146 | 158 |
|
147 |
| - a = src_lib.asarray([1, 2, 3]) |
| 159 | + a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32) |
148 | 160 | b = tgt_lib.asarray(a)
|
149 | 161 |
|
150 | 162 | assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
|
| 163 | + assert b.dtype == tgt_lib.int32 |
| 164 | + |
151 | 165 |
|
152 | 166 | @pytest.mark.parametrize("library", wrapped_libraries)
|
153 | 167 | def test_asarray_copy(library):
|
|
0 commit comments