[Blocked] [call_jax] Bridge the torch_xla and JAX mesh #8972


Open
tengyifei opened this issue Apr 14, 2025 · 3 comments
Assignees: tengyifei
Labels: enhancement (New feature or request), torchxla2

Comments


tengyifei commented Apr 14, 2025

xb.call_jax should update the JAX mesh context object to contain the same devices, in the same order as the torch_xla mesh, before entering the JAX function. This is necessary to ensure that any SPMD computation in JAX has the same semantics as the SPMD computation in torch_xla.

There is already some logic for this in the splash attention kernel, and we can factor it out.
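A minimal sketch of what the factored-out helper could look like. Everything here is hypothetical: `bridge_mesh` is not a real torch_xla or JAX API, and the device objects are stand-ins (real code would work with `jax.devices()` and the torch_xla SPMD mesh's flat device list).

```python
# Hypothetical sketch of the mesh-bridging helper described above.
# Device objects are stand-ins for real JAX devices, which expose an
# integer `.id`; none of these names come from the actual codebase.

def bridge_mesh(xla_device_ids, jax_devices):
    """Return jax_devices reordered to match the torch_xla mesh order.

    xla_device_ids: flat list of global device ids from the torch_xla mesh.
    jax_devices: devices JAX discovered, each with an `.id` attribute.
    """
    by_id = {d.id: d for d in jax_devices}
    missing = [i for i in xla_device_ids if i not in by_id]
    if missing:
        # The blocking condition from this issue: JAX did not discover
        # every device PyTorch/XLA sees (e.g. missing slices).
        raise RuntimeError(f"JAX did not discover devices: {missing}")
    return [by_id[i] for i in xla_device_ids]


# Example with stand-in devices:
class FakeDevice:
    def __init__(self, id):
        self.id = id

jax_devices = [FakeDevice(2), FakeDevice(0), FakeDevice(1)]
ordered = bridge_mesh([0, 1, 2], jax_devices)
```

The helper both reorders the devices and surfaces the "JAX sees fewer devices than PyTorch/XLA" failure mode explicitly, which is the blocked condition discussed below.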


Blocked status

In #9038, we realized this is blocked on letting JAX discover the same devices as PyTorch/XLA.

In multi-slice training, JAX currently doesn't discover all the slices; see b/374631442 and #8609 (comment).

We can't automatically set the JAX contextual mesh as that would break too many multi-slice workflows.

@tengyifei tengyifei self-assigned this Apr 16, 2025
@ysiraichi ysiraichi added enhancement New feature or request torchxla2 labels Apr 17, 2025
tengyifei (Collaborator, Author) commented:

Test: run a JAX SPMD function that accesses the ambient SPMD mesh from xb.call_jax. It should still work and end up using the devices in the PyTorch/XLA SPMD mesh.
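As a rough illustration of the wrapper pattern this implies, here is a pure-Python sketch. `MeshContext`, `bridged_mesh`, and this `call_jax` signature are all hypothetical stand-ins, not the actual xb.call_jax implementation or JAX's real contextual-mesh machinery.

```python
import contextlib

# Hypothetical sketch: xb.call_jax installs a mesh context that mirrors
# the torch_xla SPMD mesh for the duration of the JAX function, then
# restores the previous context so nothing leaks into other workflows.

class MeshContext:
    current = None  # the ambient mesh a JAX function would pick up

@contextlib.contextmanager
def bridged_mesh(mesh):
    prev = MeshContext.current
    MeshContext.current = mesh  # same devices, same order as torch_xla
    try:
        yield
    finally:
        MeshContext.current = prev

def call_jax(fn, *args, mesh=None):
    # The JAX function runs with the bridged ambient mesh in scope.
    with bridged_mesh(mesh):
        return fn(*args)

# The function sees the bridged mesh while it runs:
seen = call_jax(lambda: MeshContext.current, mesh=("tpu:0", "tpu:1"))
```

Restoring the previous context on exit is the key design point: the issue notes that setting the JAX contextual mesh globally would break multi-slice workflows, so the bridge should be scoped to the call.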

tengyifei added a commit that referenced this issue Apr 25, 2025
Now we can run a JAX SPMD function that accesses the ambient SPMD mesh
from xb.call_jax.

Fixes #8972.

Also, I beefed up the assume_pure tests and updated the docs to mention
that mark_sharding is supported, thanks to qihqi@'s #8989.
@tengyifei tengyifei changed the title [call_jax] Bridge the torch_xla and JAX mesh [Blocked] [call_jax] Bridge the torch_xla and JAX mesh Apr 25, 2025

tengyifei (Collaborator, Author) commented:

Draft PR: #9043
