Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: at should not force overwrite in Dask when copy=None #135

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,16 +275,11 @@ def _op(
msg = f"copy must be True, False, or None; got {copy!r}"
raise ValueError(msg)

if copy is None:
writeable = is_writeable_array(x)
copy = not writeable
elif copy:
writeable = None
else:
writeable = is_writeable_array(x)
writeable = None if copy else is_writeable_array(x)

# JAX inside jax.jit and Dask don't support in-place updates with boolean
# mask. However we can handle the common special case of 0-dimensional y
# JAX inside jax.jit doesn't support in-place updates with boolean
# masks; Dask exclusively supports __setitem__ but not iops.
# We can handle the common special case of 0-dimensional y
# with where(idx, y, x) instead.
if (
(is_dask_array(idx) or is_jax_array(idx))
Expand All @@ -293,21 +288,22 @@ def _op(
):
y_xp = xp.asarray(y, dtype=x.dtype)
if y_xp.ndim == 0:
if out_of_place_op:
if out_of_place_op: # add(), subtract(), ...
# FIXME: suppress inf warnings on dask with lazywhere
out = xp.where(idx, out_of_place_op(x, y_xp), x)
# Undo int->float promotion on JAX after _AtOp.DIVIDE
out = xp.astype(out, x.dtype, copy=False)
else:
else: # set()
out = xp.where(idx, y_xp, x)

if copy:
return out
x[()] = out
return x
if copy is False:
x[()] = out
return x
return out

# else: this will work on eager JAX and crash on jax.jit and Dask

if copy:
if copy or (copy is None and not writeable):
if is_jax_array(x):
# Use JAX's at[]
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
Expand All @@ -331,7 +327,7 @@ def _op(
msg = f"Can't update read-only array {x}"
raise ValueError(msg)

if in_place_op:
if in_place_op: # add(), subtract(), ...
x[self._idx] = in_place_op(x[self._idx], y)
else: # set()
x[self._idx] = y
Expand Down
169 changes: 130 additions & 39 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Callable, Generator
from contextlib import contextmanager
from types import ModuleType
from typing import Any, cast
from typing import cast

import numpy as np
import pytest
Expand All @@ -23,12 +23,13 @@
]


def at_op( # type: ignore[no-any-explicit]
def at_op(
x: Array,
idx: Index,
op: _AtOp,
y: Array | object,
**kwargs: Any, # Test the default copy=None
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
Expand All @@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit]
which is not a common use case.
"""
if isinstance(idx, (slice | tuple)):
return _at_op(x, None, pickle.dumps(idx), op, y, **kwargs)
return _at_op(x, idx, None, op, y, **kwargs)
return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp)
return _at_op(x, idx, None, op, y, copy=copy, xp=xp)


def _at_op( # type: ignore[no-any-explicit]
def _at_op(
x: Array,
idx: Index | None,
idx_pickle: bytes | None,
op: _AtOp,
y: Array | object,
**kwargs: Any,
copy: bool | None,
xp: ModuleType | None = None,
) -> Array:
"""jitted helper of at_op"""
if idx_pickle:
idx = pickle.loads(idx_pickle)
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
return meth(y, **kwargs)
return meth(y, copy=copy, xp=xp)


lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp"))


@contextmanager
def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
def assert_copy(
array: Array, copy: bool | None, expect_copy: bool | None = None
) -> Generator[None, None, None]:
if copy is False and not is_writeable_array(array):
with pytest.raises((TypeError, ValueError)):
yield
Expand All @@ -72,24 +76,23 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
array_orig = xp.asarray(array, copy=True)
yield

if copy is None:
copy = not is_writeable_array(array)
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))
if expect_copy is None:
expect_copy = copy

if expect_copy:
# Original has not been modified
xp_assert_equal(array, array_orig)
elif expect_copy is False:
# Original has been modified
with pytest.raises(AssertionError):
xp_assert_equal(array, array_orig)
# Test nothing for copy=None. Dask changes behaviour depending on
# whether it's a special case of a bool mask with scalar RHS or not.


@pytest.mark.parametrize("copy", [False, True, None])
@pytest.mark.parametrize(
("kwargs", "expect_copy"),
[
pytest.param({"copy": True}, True, id="copy=True"),
pytest.param({"copy": False}, False, id="copy=False"),
# Behavior is backend-specific
pytest.param({"copy": None}, None, id="copy=None"),
# Test that the copy parameter defaults to None
pytest.param({}, None, id="no copy kwarg"),
],
)
@pytest.mark.parametrize(
("op", "y", "expect"),
("op", "y", "expect_list"),
[
(_AtOp.SET, 40.0, [10.0, 40.0, 40.0]),
(_AtOp.ADD, 40.0, [10.0, 60.0, 70.0]),
Expand All @@ -102,14 +105,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
],
)
@pytest.mark.parametrize(
("bool_mask", "shaped_y"),
("bool_mask", "x_ndim", "y_ndim"),
[
(False, False),
(False, True),
(True, False), # Uses xp.where(idx, y, x) on JAX and Dask
(False, 1, 0),
(False, 1, 1),
(True, 1, 0), # Uses xp.where(idx, y, x) on JAX and Dask
pytest.param(
True,
True,
*(True, 1, 1),
marks=(
pytest.mark.skip_xp_backend( # test passes when copy=False
Backend.JAX, reason="bool mask update with shaped rhs"
Expand All @@ -119,29 +121,65 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
),
),
),
(False, 0, 0),
(True, 0, 0),
],
)
def test_update_ops(
xp: ModuleType,
kwargs: dict[str, bool | None],
expect_copy: bool | None,
copy: bool | None,
op: _AtOp,
y: float,
expect: list[float],
expect_list: list[float],
bool_mask: bool,
shaped_y: bool,
x_ndim: int,
y_ndim: int,
):
x = xp.asarray([10.0, 20.0, 30.0])
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
if shaped_y:
if x_ndim == 1:
x = xp.asarray([10.0, 20.0, 30.0])
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
expect: list[float] | float = expect_list
else:
idx = xp.asarray(True) if bool_mask else ()
# Pick an element that does change with the operation
if op is _AtOp.MIN:
x = xp.asarray(30.0)
expect = expect_list[2]
else:
x = xp.asarray(20.0)
expect = expect_list[1]

if y_ndim == 1:
y = xp.asarray([y, y])

with assert_copy(x, expect_copy):
z = at_op(x, idx, op, y, **kwargs)
with assert_copy(x, copy):
z = at_op(x, idx, op, y, copy=copy)
assert isinstance(z, type(x))
xp_assert_equal(z, xp.asarray(expect))


@pytest.mark.parametrize("op", list(_AtOp))
def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
"""
Test that the default copy behaviour is False for writeable arrays
and True for read-only ones.
"""
x = xp.asarray([1.0, 10.0, 20.0])
expect_copy = not is_writeable_array(x)
meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value)) # type: ignore[no-any-explicit]
with assert_copy(x, None, expect_copy):
_ = meth(2.0)

x = xp.asarray([1.0, 10.0, 20.0])
# Dask's default copy value is True for bool masks,
# even if the arrays are writeable.
expect_copy = not is_writeable_array(x) or library is Backend.DASK
idx = xp.asarray([True, True, False])
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
with assert_copy(x, None, expect_copy):
_ = meth(2.0)


def test_copy_invalid():
a = np.asarray([1, 2, 3])
with pytest.raises(ValueError, match="copy"):
Expand Down Expand Up @@ -259,3 +297,56 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
# inf - inf -> nan with a warning
z = at_op(x, idx, _AtOp.SUBTRACT, math.inf)
xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf]))


@pytest.mark.parametrize(
"copy",
[
None,
pytest.param(
False,
marks=[
pytest.mark.skip_xp_backend(
Backend.NUMPY, reason="np.generic is read-only"
),
pytest.mark.skip_xp_backend(
Backend.NUMPY_READONLY, reason="read-only backend"
),
pytest.mark.skip_xp_backend(Backend.JAX, reason="read-only backend"),
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="read-only backend"),
pytest.mark.xfail_xp_backend(Backend.DASK, reason="dask/dask#11722"),
],
),
],
)
@pytest.mark.parametrize(
"bool_mask",
[
pytest.param(
False,
marks=pytest.mark.xfail_xp_backend(Backend.DASK, reason="dask/dask#11722"),
),
True,
],
)
def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None):
"""
Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
blindly assumes that chunk contents are writeable np.ndarray objects:

https://github.com/dask/dask/issues/11722

In other words: when special-casing bool masks for Dask, unless the user explicitly
asks for copy=False, do not needlessly write back to the input.
"""
x = xp.zeros(1)

# In numpy, we have a writeable np.ndarray in input and a read-only np.generic in
# output. As both are Arrays, this behaviour is Array API compliant.
# In Dask, we have a writeable da.Array on both sides, and if you call __setitem__
# on it all seems fine, but when you compute() your graph is corrupted.
y = x[0]

idx = xp.asarray(True) if bool_mask else ()
z = at_op(y, idx, _AtOp.SET, 1, copy=copy)
xp_assert_equal(z, xp.asarray(1, dtype=x.dtype))