Skip to content

Commit 72d7635

Browse files
authored
more reduction tests (#11)
* boolean reduce * order reduce * index ordering tests * cumulative reduce tests `cumprod` / `cumulative_prod` are not part of the array API * skip `cumprod` since it's most likely not implemented * fix the expected value for cumulative reductions `Variable` implements n-d cum ops by iterating over the axes. * pin `numpy<2.1` until `xarray` issues a new release * resolve the `argmin` / `argmax` warnings by opting into the new behavior * try installing the nightly version of `xarray` * print the type of `actual` if it didn't match * try the recommended syntax for `pip` installing deps * put the options on a single line * upgrade packages * install nightly `xarray` as a separate step
1 parent 3c092cb commit 72d7635

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

.github/workflows/ci.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ jobs:
6565
python=${{matrix.python-version}}
6666
conda
6767
68+
- name: Install nightly xarray
69+
run: |
70+
python -m pip install --upgrade --pre -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple xarray
71+
6872
- name: Install xarray-array-testing
6973
run: |
7074
python -m pip install --no-deps -e .

xarray_array_testing/reduction.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from contextlib import nullcontext
22

33
import hypothesis.strategies as st
4+
import numpy as np
45
import pytest
56
import xarray.testing.strategies as xrst
67
from hypothesis import given
@@ -24,4 +25,76 @@ def test_variable_numerical_reduce(self, op, data):
2425
# compute using xp.<OP>(array)
2526
expected = getattr(self.xp, op)(variable.data)
2627

28+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
29+
self.assert_equal(actual, expected)
30+
31+
@pytest.mark.parametrize("op", ["all", "any"])
32+
@given(st.data())
33+
def test_variable_boolean_reduce(self, op, data):
34+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
35+
36+
with self.expected_errors(op, variable=variable):
37+
# compute using xr.Variable.<OP>()
38+
actual = getattr(variable, op)().data
39+
# compute using xp.<OP>(array)
40+
expected = getattr(self.xp, op)(variable.data)
41+
42+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
43+
self.assert_equal(actual, expected)
44+
45+
@pytest.mark.parametrize("op", ["max", "min"])
46+
@given(st.data())
47+
def test_variable_order_reduce(self, op, data):
48+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
49+
50+
with self.expected_errors(op, variable=variable):
51+
# compute using xr.Variable.<OP>()
52+
actual = getattr(variable, op)().data
53+
# compute using xp.<OP>(array)
54+
expected = getattr(self.xp, op)(variable.data)
55+
56+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
57+
self.assert_equal(actual, expected)
58+
59+
@pytest.mark.parametrize("op", ["argmax", "argmin"])
60+
@given(st.data())
61+
def test_variable_order_reduce_index(self, op, data):
62+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
63+
64+
with self.expected_errors(op, variable=variable):
65+
# compute using xr.Variable.<OP>()
66+
actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()}
67+
68+
# compute using xp.<OP>(array)
69+
index = getattr(self.xp, op)(variable.data)
70+
unraveled = np.unravel_index(index, variable.shape)
71+
expected = dict(zip(variable.dims, unraveled))
72+
73+
self.assert_equal(actual, expected)
74+
75+
@pytest.mark.parametrize(
76+
"op",
77+
[
78+
"cumsum",
79+
pytest.param(
80+
"cumprod",
81+
marks=pytest.mark.skip(reason="not yet included in the array api"),
82+
),
83+
],
84+
)
85+
@given(st.data())
86+
def test_variable_cumulative_reduce(self, op, data):
87+
array_api_names = {"cumsum": "cumulative_sum", "cumprod": "cumulative_prod"}
88+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
89+
90+
with self.expected_errors(op, variable=variable):
91+
# compute using xr.Variable.<OP>()
92+
actual = getattr(variable, op)().data
93+
# compute using xp.<OP>(array)
94+
# Variable implements n-d cumulative ops by iterating over dims
95+
expected = variable.data
96+
for axis in range(variable.ndim):
97+
expected = getattr(self.xp, array_api_names[op])(expected, axis=axis)
98+
99+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
27100
self.assert_equal(actual, expected)

0 commit comments

Comments
 (0)