Skip to content

Commit 0a2160b

Browse files
committed
Also ignore None in array_namespace
1 parent bd536d6 commit 0a2160b

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

array_api_compat/common/_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
236236
----------
237237
xs: arrays
238238
one or more arrays. xs can also be Python scalars (bool, int, float,
239-
or complex), which are ignored.
239+
complex, or None), which are ignored.
240240
241241
api_version: str
242242
The newest version of the spec that you need support for (currently
@@ -299,7 +299,7 @@ def your_function(x, y):
299299

300300
namespaces = set()
301301
for x in xs:
302-
if isinstance(x, (bool, int, float, complex)):
302+
if isinstance(x, (bool, int, float, complex, type(None))):
303303
continue
304304
elif is_numpy_array(x):
305305
from .. import numpy as numpy_namespace

tests/test_array_namespace.py

+2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,10 @@ def test_python_scalars():
101101
pytest.raises(TypeError, lambda: array_namespace(1.0))
102102
pytest.raises(TypeError, lambda: array_namespace(1j))
103103
pytest.raises(TypeError, lambda: array_namespace(True))
104+
pytest.raises(TypeError, lambda: array_namespace(None))
104105

105106
assert array_namespace(a, 1) == xp
106107
assert array_namespace(a, 1.0) == xp
107108
assert array_namespace(a, 1j) == xp
108109
assert array_namespace(a, True) == xp
110+
assert array_namespace(a, None) == xp

0 commit comments

Comments
 (0)