[Blocked] [call_jax] Bridge the torch_xla and JAX mesh #8972
Test: run a JAX SPMD function that accesses the ambient SPMD mesh from
Draft PR: #9043
xb.call_jax should update the JAX mesh context object to contain the same devices, in the same order, before entering the JAX function. This is necessary to ensure that any SPMD computation in JAX has the same semantics as the SPMD computation in torch_xla. There is some logic for this in the splash attention kernel that we can factor out.
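A minimal sketch of the idea, not the actual implementation: it assumes the torch_xla mesh exposes `device_ids`, `mesh_shape`, and `axis_names` (as torch_xla.distributed.spmd.Mesh does), and that entering a `jax.sharding.Mesh` installs it as the ambient mesh. The helper name is hypothetical.

```python
import contextlib

import jax
import numpy as np


@contextlib.contextmanager
def jax_mesh_from_torch_xla(xla_mesh):
    """Enter a JAX contextual mesh mirroring a torch_xla SPMD mesh.

    Sketch under the assumption that `xla_mesh` looks like
    torch_xla.distributed.spmd.Mesh: `device_ids` is a flat list of
    global device ordinals, `mesh_shape` is the logical shape, and
    `axis_names` names the mesh axes.
    """
    jax_devices = jax.devices()
    # Reorder JAX's devices to match torch_xla's device order, then
    # reshape them into the same logical mesh shape.
    devices = np.array(
        [jax_devices[i] for i in xla_mesh.device_ids], dtype=object
    ).reshape(xla_mesh.mesh_shape)
    mesh = jax.sharding.Mesh(devices, xla_mesh.axis_names)
    # Entering the Mesh makes it the ambient mesh that named shardings
    # inside the JAX function resolve their axis names against.
    with mesh:
        yield mesh
```

call_jax could wrap the user's JAX function in such a context manager, so that the ambient JAX mesh always matches the torch_xla one.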
Blocked status
In #9038, we realized this is blocked on letting JAX discover the same devices as PyTorch/XLA.
In multi-slice training, JAX currently does not discover all the slices due to b/374631442 and #8609 (comment).
We can't automatically set the JAX contextual mesh, as that would break too many multi-slice workflows.
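For context, a quick parity check of the kind that exposes the gap, as a sketch: `xr.global_runtime_device_count()` and `jax.device_count()` are real APIs, but the specific failure mode described in the comment is an assumption based on b/374631442.

```python
import jax
import torch_xla.runtime as xr

# On a healthy single-slice setup these agree. In multi-slice training
# today, jax.device_count() can come up short because JAX only discovers
# its local slice (b/374631442), so the two meshes cannot be bridged.
torch_xla_devices = xr.global_runtime_device_count()
jax_devices = jax.device_count()
assert jax_devices == torch_xla_devices, (
    f"JAX sees {jax_devices} devices but torch_xla sees "
    f"{torch_xla_devices}; call_jax cannot mirror the mesh."
)
```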