Add mask argument to lax.argmax #25623
Comments
Thanks for the request! I don't think …
>>> import jax
>>> from functools import partial
>>> argmax = partial(jax.lax.argmax, axis=0, index_dtype='int32')
>>> jax.make_jaxpr(argmax)(jax.numpy.arange(10))
{ lambda ; a:i32[10]. let
    b:i32[] = argmax[axes=(0,) index_dtype=int32] a
  in (b,) }
The functions you propose would lower to multiple primitives, so they're not really a great fit for … There are a couple of ways we could move forward with an idea like this: …
What do you think?
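To make the point about primitives concrete, here is a small check. It assumes a masked variant written as jnp.where followed by lax.argmax; masked_argmax is a hypothetical helper for illustration, not a JAX API. The plain call traces to a single argmax equation, while the masked variant adds at least one more (typically a select for the where), which is what "lowers to multiple primitives" looks like in a jaxpr.

```python
from functools import partial

import jax
import jax.numpy as jnp

# The unmasked reduction from the snippet: it binds a single `argmax` primitive.
argmax = partial(jax.lax.argmax, axis=0, index_dtype=jnp.int32)


# Hypothetical masked variant built from existing ops: masked-out entries are
# replaced by -inf so they can never be the maximum, then argmax runs as usual.
def masked_argmax(x, mask):
    return argmax(jnp.where(mask, x, -jnp.inf))


x = jnp.arange(10, dtype=jnp.float32)
mask = x < 5  # only the first five entries participate

plain_eqns = jax.make_jaxpr(argmax)(x).jaxpr.eqns
masked_eqns = jax.make_jaxpr(masked_argmax)(x, mask).jaxpr.eqns

# Primitive names in each traced program; the masked one is longer.
print([eqn.primitive.name for eqn in plain_eqns])
print([eqn.primitive.name for eqn in masked_eqns])
```

The exact extra equations may differ between JAX versions, so the sketch only relies on there being more than one.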
@jakevdp If there's no place to add such a function, I guess we'll have to wait for …
It's a bit of a shame that we have to wait for NumPy to add useful functionality.
We don't have to wait for NumPy, but a concern here is that we add some keyword argument to JAX, and in the future NumPy adds a conflicting keyword argument. Then it becomes a somewhat painful deprecation cycle to match upstream semantics. We went through this in the past year with the Array API, and I'd like to avoid that if possible.
That's a valid concern, I agree. Is there a way to directly "petition" for the inclusion of a mask argument to argmax in the Array API? (Apparently the same goes for max.) To be honest, I'm not entirely clear on how design decisions flow between JAX, NumPy, and the Array API. NumPy itself is aiming to target the Array API. In the long term, is JAX aiming to target NumPy, or the Array API directly? My understanding is that the Array API is shaped by an informal kind of "consensus" between popular libraries (such as NumPy, JAX, and PyTorch). Doesn't that make the whole thing circular, a catch-22, when it comes to making design decisions?
Historically, before the Array API existed, …
I don't think there's anything circular here: if the Array API adds a function or argument to its spec, JAX will adopt it, as will NumPy. If NumPy adds a function or argument to its implementation, JAX will adopt it. There is some influence back up: for example, the Array API maintainers have historically been careful not to introduce things that conflict with existing symbols, but that's not a hard requirement (for example, JAX historically had an …
Thanks for the clarification. Do you know the answer to this question, by any chance?
IMO, it makes logical sense that any operation involving reduction(s) should take an optional …
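For context, NumPy-style reductions in jax.numpy already accept a boolean mask through a where= argument, while argmax does not (in jax.numpy or in NumPy). The snippet below is a small illustration with made-up data; for max and min, which have no identity element, an initial= value must be supplied alongside where=.

```python
import jax.numpy as jnp

x = jnp.array([3.0, 9.0, 1.0, 7.0])
mask = jnp.array([True, False, True, True])  # exclude the 9.0

# Masked max: `initial` gives the result for a slice with no unmasked entries;
# without it, jnp.max raises an error when `where` is passed.
masked_max = jnp.max(x, where=mask, initial=-jnp.inf)

# The masked max is 7.0 because 9.0 is excluded by the mask.
print(masked_max)

# jnp.argmax has no `where=` parameter, which is the gap under discussion.
```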
The best way to propose this is probably to open an issue at https://github.com/data-apis/array-api proposing the change.
Done: data-apis/array-api#875.
Feature request: Add a mask argument to jax.lax.argmax (and jax.lax.argmin): an array of booleans indicating which elements to include in the computation. Here is an example implementation:
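The example implementation from the request isn't preserved in this copy. As a stand-in, here is a minimal sketch of the requested semantics built from existing public ops. The names masked_argmax, masked_argmin, and _sentinel are hypothetical, and the sentinel-fill approach is an assumption, not the requester's code. It supports floating-point and integer dtypes only.

```python
import jax
import jax.numpy as jnp


def _sentinel(dtype, lowest):
    """Value that should never win: the lowest value for argmax, the highest for argmin."""
    if jnp.issubdtype(dtype, jnp.floating):
        return -jnp.inf if lowest else jnp.inf
    info = jnp.iinfo(dtype)  # integer dtypes only; bool is not handled
    return info.min if lowest else info.max


def masked_argmax(operand, axis, index_dtype, mask):
    """Hypothetical masked argmax: entries where `mask` is False are ignored."""
    filled = jnp.where(mask, operand, _sentinel(operand.dtype, lowest=True))
    return jax.lax.argmax(filled, axis, index_dtype)


def masked_argmin(operand, axis, index_dtype, mask):
    """Hypothetical masked argmin: entries where `mask` is False are ignored."""
    filled = jnp.where(mask, operand, _sentinel(operand.dtype, lowest=False))
    return jax.lax.argmin(filled, axis, index_dtype)


x = jnp.array([[5, 1, 9], [2, 8, 3]], dtype=jnp.int32)
mask = jnp.array([[True, True, False], [True, False, True]])

best = masked_argmax(x, 1, jnp.int32, mask)   # per-row index of the largest kept value
worst = masked_argmin(x, 1, jnp.int32, mask)  # per-row index of the smallest kept value
```

A known limitation of the fill approach: if a slice is fully masked, or a kept entry equals the sentinel, ties can resolve to the index of a masked-out entry. A dedicated mask argument could define that case explicitly.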
I can submit a PR for this.
Related, but for jax.numpy: #20177.
A more general solution that could be reused for other operations in the future would be to add a mask argument to jax.lax.reduce that controls which elements to include in the reduction. Example: