at.set behavior for Dask arrays? #134

Open
mdhaber opened this issue Feb 5, 2025 · 7 comments · May be fixed by #135

Comments

@mdhaber
Contributor

mdhaber commented Feb 5, 2025

Re: scipy/scipy#22469 (comment)

I encountered some code involving Dask and at.set that works as expected in 0.6.0 but not in the more recent commit used by SciPy (c23ac01).

import numpy as np
import dask.array as xp
from scipy._lib._array_api import array_namespace

# import array_api_extra as xpx  # this works
import scipy._lib.array_api_extra as xpx  # this doesn't

x = xp.asarray([1., 2., 3.])
y = xp.asarray([1., 2.5, 3.])
xp = array_namespace(x, y)

r = xp.vecdot(x, y)
mask = xp.asarray(False)
y = xp.asarray(xp.nan)

# r[mask] = xp.nan  # This works
r = xpx.at(r, mask).set(y)  # This doesn't

np.asarray(r)
# TypeError: 'numpy.float64' object does not support item assignment

But this brings up a question - why does array_api_extra.at.set get involved with Dask here anyway? The documentation of Dask Array says:
[image: screenshot from the Dask Array documentation]

The last line is what's relevant. Doing the regular indexing assignment (when y is a scalar, at least) seems to be something Dask can do, so why replace that with a where?

@crusaderky
Contributor

crusaderky commented Feb 6, 2025

The last line is what's relevant. Doing the regular indexing assignment (when y is a scalar, at least) seems to be something Dask can do, so why replace that with a where?

The documentation you reference is specifically for assignment. In-place operations don't work with bool masks:

>>> import dask.array as da
>>> a = da.arange(5)
>>> mask = da.zeros_like(a, dtype=bool)
>>> a[mask] = 1  # OK
>>> a[mask] += 1

ValueError: Boolean index assignment in Dask expects equally shaped arrays.
Example: da1[da2] = da3 where da1.shape == (4,), da2.shape == (4,) and da3.shape == (4,).
Alternatively, you can use the extended API that supportsindexing with tuples.
Example: da1[(da2,)] = da3.

What's happening under the hood, which also explains the misleading exception, is that there is no dask.array.Array.__iadd__. So when you ask for a[mask] += 1, the Python interpreter executes a[mask] = a[mask] + 1, which fails because both lhs and rhs have shape=(nan, ). (Not defending Dask here - this is clearly something fixable on their end).
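
To make the desugaring above concrete, here is a minimal pure-Python sketch. The `Recorder` class is hypothetical and only logs which special methods the interpreter calls; it says nothing about Dask internals. Because the element it returns (a plain int) has no `__iadd__`, Python falls back to `__add__` and then writes the result back through `__setitem__`.

```python
class Recorder:
    """Toy container (hypothetical) that logs the special methods called on it."""

    def __init__(self, data):
        self.data = list(data)
        self.calls = []

    def __getitem__(self, idx):
        self.calls.append("__getitem__")
        return self.data[idx]

    def __setitem__(self, idx, value):
        self.calls.append("__setitem__")
        self.data[idx] = value


a = Recorder([10, 20, 30])
a[1] += 1  # read via __getitem__, add, then write back via __setitem__

print(a.calls)  # ['__getitem__', '__setitem__']
print(a.data)   # [10, 21, 30]
```

So `a[mask] += 1` always ends in a `__setitem__` call whose right-hand side was produced from `a[mask]`, which is where the shape check in the traceback above comes in.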


I've reproduced your failure. It's a bug in da.Array.__getitem__(idx).__setitem__(()) where idx selects a scalar.
The object returned by __getitem__ is another da.Array, which is writeable, but internally the chunk contains a np.generic, which is not.
Moving on to __setitem__, dask blindly assumes that its chunks are always writeable np.ndarray objects:

    import dask.array as da
    x = da.zeros(1)
    y = x[0]  # a writeable da.Array, but the chunk is a read-only np.generic
    y[()] = 1  # No failure here
    y.compute()  # TypeError: 'numpy.float64' object does not support item assignment

In turn, this caused a call inside xpx.at to array_api_compat.is_writeable_array(y) to return True, under the assumption that all Dask arrays are writeable - and in fact y is writeable, in the sense that y[()] = 1 doesn't raise, but now you have a corrupted graph.

FWIW the code above fails for numpy too, but numpy is technically compliant with the API standard here: the standard caters for read-only arrays, but makes no provision that a function (in this case __getitem__, but np.vecdot is the same) that takes a writeable array as input must also return one - which leads to surprising behaviour.
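
For reference, a NumPy-only sketch of the same situation, assuming nothing beyond plain NumPy: integer indexing a 1-d array returns a `np.generic` scalar, which rejects item assignment, while a 0-d `ndarray` accepts it.

```python
import numpy as np

x = np.zeros(1)
y = x[0]  # integer indexing returns a np.float64 scalar, not a 0-d ndarray

print(isinstance(y, np.generic))  # True
print(isinstance(y, np.ndarray))  # False

error = None
try:
    y[()] = 1  # scalars do not implement __setitem__
except TypeError as exc:
    error = str(exc)
    print(error)

# A 0-d ndarray, by contrast, accepts the same assignment.
z = np.zeros(())
z[()] = 1
print(z)
```

The difference from the Dask case is only in timing: NumPy raises at the assignment, whereas Dask defers the failure to compute().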

@crusaderky
Contributor

Upstream issue: dask/dask#11722

xpx patch: #135
This fixes the special case of bool masks, which was unnecessarily overwriting the input in Dask, but doesn't fix a tuple index: xpx.at(r, ()).set(xp.nan) will keep failing for as long as r[()] = xp.nan fails.

@crusaderky
Contributor

Upstream fix: dask/dask#11723

@mdhaber
Contributor Author

mdhaber commented Feb 6, 2025

Thanks for taking a look, creating the upstream issue, and providing a patch for xpx. When that merges, I guess we can update the SciPy commit and merge scipy/scipy#22469, where this came up.

In that context, though (specifically this line), it does seem like an in-place assignment would be preferable to where, since it's dealing with an exceptional/unusual/rare case. I think it would be nice if we let Dask do in-place assignment where possible.

@crusaderky
Contributor

crusaderky commented Feb 7, 2025

I think it would be nice if we let Dask do in-place assignment where possible.

Except... Dask in-place assignment with bool masks is internally implemented with where, but unlike array-api-extra they failed to implement dtype propagation rules:

>>> import numpy as np
>>> a = np.asarray([1,2])
>>> a[[True, False]] = 3.3
>>> a
array([3, 2])

>>> import dask.array as da
>>> a = da.asarray([1, 2])
>>> a[da.asarray([True, False])] = 3.3
>>> a
dask.array<where, shape=(2,), dtype=float64, chunksize=(2,), chunktype=numpy.ndarray>
>>> a.compute()
array([3.3, 2. ])
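
The dtype difference can be reproduced with NumPy alone. This is only a sketch of the semantics and an assumption about what "dtype propagation" amounts to here; it is not how array-api-extra or Dask are actually implemented. A bare where promotes to float64, and casting back to the input dtype recovers the in-place result.

```python
import numpy as np

a = np.asarray([1, 2])
mask = np.asarray([True, False])

# In-place masked assignment keeps the integer dtype; 3.3 is truncated to 3.
in_place = a.copy()
in_place[mask] = 3.3

# A bare where() promotes the result to float64 instead.
via_where = np.where(mask, 3.3, a)

# Casting back to the input dtype matches the in-place semantics.
cast_back = via_where.astype(a.dtype)

print(in_place, in_place.dtype)
print(via_where, via_where.dtype)
print(cast_back, cast_back.dtype)
```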

@crusaderky
Contributor

dask/dask#11724

@mdhaber
Contributor Author

mdhaber commented Feb 7, 2025

Ah, ok then.
