Skip to content

Commit b3eafaf

Browse files
authored
CI: Add array-api-tests CI job and upgrade Array API version (#121)
* CI: Add `array-api-tests` CI job * Upgrade Array API to 2024.12
1 parent 911e750 commit b3eafaf

9 files changed

+610
-11
lines changed

.github/workflows/ci.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,38 @@ jobs:
3131
poetry run pytest --junit-xml=test-${{ matrix.os }}-Python-${{ matrix.python }}.xml
3232
- uses: codecov/codecov-action@v3
3333

34+
array_api_tests:
35+
env:
36+
ARRAY_API_TESTS_DIR: ${{ github.workspace }}/array-api-tests
37+
runs-on: ubuntu-latest
38+
steps:
39+
- name: Checkout Repo
40+
uses: actions/checkout@v4
41+
- name: Checkout array-api-tests
42+
run: ci/clone_array_api_tests.sh
43+
- name: Set up Python
44+
uses: actions/setup-python@v5
45+
with:
46+
python-version: '3.11'
47+
cache: 'pip'
48+
- name: Install Poetry
49+
run: |
50+
pip install poetry
51+
- name: Build wheel
52+
run: |
53+
python -m poetry build --format wheel
54+
- name: Install build and test dependencies from PyPI
55+
run: |
56+
pip install dist/*.whl
57+
pip install -U setuptools wheel
58+
pip install pytest-xdist hypothesis==6.131.0 -r "$ARRAY_API_TESTS_DIR/requirements.txt"
59+
- name: Run the test suite
60+
env:
61+
ARRAY_API_TESTS_MODULE: finch
62+
run: |
63+
python -c 'import finch'
64+
pytest "$ARRAY_API_TESTS_DIR/array_api_tests/" -v -c "$ARRAY_API_TESTS_DIR/pytest.ini" --ci --max-examples=2 --derandomize --disable-deadline --disable-warnings -n auto --skips-file ci/array-api-skips.txt
65+
3466
on:
3567
# Trigger the workflow on push or pull request,
3668
# but only for the main branch

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2023 Willow Ahrens
3+
Copyright (c) 2025 Willow Ahrens
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

ci/array-api-skips.txt

Lines changed: 461 additions & 0 deletions
Large diffs are not rendered by default.

ci/array-api-tests-rev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
c48410f96fc58e02eea844e6b7f6cc01680f77ce

ci/clone_array_api_tests.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/usr/bin/env bash
2+
set -euxo pipefail
3+
4+
ARRAY_API_TESTS_DIR="${ARRAY_API_TESTS_DIR:-"../array-api-tests"}"
5+
if [ ! -d "$ARRAY_API_TESTS_DIR" ]; then
6+
git clone --recursive https://github.com/data-apis/array-api-tests.git "$ARRAY_API_TESTS_DIR"
7+
fi
8+
9+
git --git-dir="$ARRAY_API_TESTS_DIR/.git" --work-tree "$ARRAY_API_TESTS_DIR" clean -xddf
10+
git --git-dir="$ARRAY_API_TESTS_DIR/.git" --work-tree "$ARRAY_API_TESTS_DIR" fetch
11+
git --git-dir="$ARRAY_API_TESTS_DIR/.git" --work-tree "$ARRAY_API_TESTS_DIR" reset --hard $(cat "ci/array-api-tests-rev.txt")

src/finch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@
151151
write,
152152
)
153153
from . import linalg
154+
from ._array_api_info import __array_namespace_info__
154155

155156
__all__ = [
156157
"Tensor",
@@ -287,4 +288,4 @@
287288
"linalg",
288289
]
289290

290-
__array_api_version__: str = "2023.12"
291+
__array_api_version__: str = "2024.12"

src/finch/_array_api_info.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from . import dtypes
2+
from .typing import DType
3+
4+
5+
class __array_namespace_info__:
6+
7+
def capabilities(self) -> dict[str, bool]:
8+
return {
9+
"boolean indexing": True, "data-dependent shapes": True,
10+
}
11+
12+
def default_device(self) -> str:
13+
return "cpu"
14+
15+
def default_dtypes(self, *, device: str | None = None) -> dict[str, DType]:
16+
if device not in ["cpu", None]:
17+
raise ValueError(
18+
"Device not understood. Only \"cpu\" is allowed, but "
19+
f"received: {device}"
20+
)
21+
return {
22+
"real floating": dtypes.float64,
23+
"complex floating": dtypes.complex128,
24+
"integral": dtypes.int_,
25+
"indexing": dtypes.int_,
26+
}
27+
28+
_bool_dtypes = {"bool": dtypes.bool}
29+
_signed_integer_dtypes = {
30+
"int8": dtypes.int8,
31+
"int16": dtypes.int16,
32+
"int32": dtypes.int32,
33+
"int64": dtypes.int64,
34+
}
35+
_unsigned_integer_dtypes = {
36+
"uint8": dtypes.uint8,
37+
"uint16": dtypes.uint16,
38+
"uint32": dtypes.uint32,
39+
"uint64": dtypes.uint64,
40+
}
41+
_real_floating_dtypes = {
42+
"float32": dtypes.float32,
43+
"float64": dtypes.float64,
44+
}
45+
_complex_floating_dtypes = {
46+
"complex64": dtypes.complex64,
47+
"complex128": dtypes.complex128,
48+
}
49+
50+
def dtypes(
51+
self,
52+
*,
53+
device: str | None = None,
54+
kind: str | tuple[str, ...] | None = None,
55+
) -> dict[str, DType]:
56+
if device not in ["cpu", None]:
57+
raise ValueError(
58+
"Device not understood. Only \"cpu\" is allowed, but "
59+
f"received: {device}"
60+
)
61+
if kind is None:
62+
return (
63+
self._bool_dtypes | self._signed_integer_dtypes |
64+
self._unsigned_integer_dtypes | self._real_floating_dtypes |
65+
self._complex_floating_dtypes
66+
)
67+
if kind == "bool":
68+
return self._bool_dtypes
69+
if kind == "signed integer":
70+
return self._signed_integer_dtypes
71+
if kind == "unsigned integer":
72+
return self._unsigned_integer_dtypes
73+
if kind == "integral":
74+
return self._signed_integer_dtypes | self._unsigned_integer_dtypes
75+
if kind == "real floating":
76+
return self._real_floating_dtypes
77+
if kind == "complex floating":
78+
return self._complex_floating_dtypes
79+
if kind == "numeric":
80+
return (
81+
self._signed_integer_dtypes | self._unsigned_integer_dtypes |
82+
self._real_floating_dtypes | self._complex_floating_dtypes
83+
)
84+
if isinstance(kind, tuple):
85+
res = {}
86+
for k in kind:
87+
res.update(self.dtypes(kind=k))
88+
return res
89+
raise ValueError(f"unsupported kind: {kind!r}")
90+
91+
def devices(self) -> list[str]:
92+
return ["cpu"]

src/finch/juliapkg.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"packages": {
33
"Finch": {
44
"uuid": "9177782c-1635-4eb9-9bfb-d9dfa25e6bce",
5-
"version": "1.2.5"
5+
"version": "1.2.7"
66
},
77
"HDF5": {
88
"uuid": "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f",

src/finch/tensor.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,9 @@ def _raise_julia_copy_not_supported() -> None:
645645

646646
def __array_namespace__(self, *, api_version: str | None = None) -> Any:
647647
if api_version is None:
648-
api_version = "2023.12"
648+
api_version = "2024.12"
649649

650-
if api_version not in {"2021.12", "2022.12", "2023.12"}:
650+
if api_version not in {"2021.12", "2022.12", "2023.12", "2024.12"}:
651651
raise ValueError(f'"{api_version}" Array API version not supported.')
652652
import finch
653653

@@ -717,12 +717,13 @@ def reshape(
717717
) -> Tensor:
718718
if copy is False:
719719
raise ValueError("Unable to avoid copy during reshape.")
720-
dims = [dim if dim >= 0 else jl.Colon() for dim in shape]
721-
obj = jl.swizzle(x._obj, *tuple(reversed(range(1, jl.ndims(x._obj) + 1))))
722-
obj = jl.reshape(obj, *reversed(dims))
723-
obj = jl.swizzle(obj, *tuple(reversed(range(1, jl.ndims(obj) + 1))))
724-
return Tensor(obj)
725-
720+
# TODO: https://github.com/finch-tensor/Finch.jl/issues/743
721+
# Revert to `jl.reshape` implementation once aforementioned
722+
# issue is solved.
723+
warnings.warn("`reshape` densified the input tensor.", PerformanceWarning)
724+
arr = x.todense()
725+
arr = arr.reshape(shape)
726+
return Tensor(arr)
726727

727728
def full(
728729
shape: int | tuple[int, ...],

0 commit comments

Comments
 (0)