diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eb1e726..83aed2e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,8 +62,8 @@ jobs: - name: Test package run: >- - python -m pytest -ra --cov --cov-report=xml --cov-report=term - --durations=20 + python -m pytest tests/ src/ docs/ -ra --cov --cov-report=xml + --cov-report=term --durations=20 - name: Upload coverage report uses: codecov/codecov-action@v4.1.0 diff --git a/src/array_api_jax_compat/_elementwise_functions.py b/src/array_api_jax_compat/_elementwise_functions.py index 650745e..0301fb0 100644 --- a/src/array_api_jax_compat/_elementwise_functions.py +++ b/src/array_api_jax_compat/_elementwise_functions.py @@ -215,7 +215,12 @@ def isfinite(x: ArrayLike, /) -> Value: @quaxify def isinf(x: ArrayLike, /) -> Value: - return array_api.isinf(x) + # Jax `isinf` makes a numpy array with value `inf` and then compares it with + # the input array. If the input array cannot be compared to base numpy + # arrays, e.g. a Quantity with units, then Jax's `isinf` will raise an + # unwanted error. Instead, we just negate the `isfinite` function, which + # should work for all array-like inputs. + return ~array_api.isfinite(x) @quaxify