Improve assume_pure docs and tests #9038

Merged 4 commits from yifeit/call-jax-mesh into master on Apr 26, 2025
Conversation

tengyifei (Collaborator) commented on Apr 25, 2025:

Beefed up the assume_pure tests and updated the docs to mention that mark_sharding is supported, thanks to qihqi@'s #8989.

Also updated yapf in the dev image to match CI.
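
A minimal sketch of the pattern the new tests exercise (not copied from the PR's test file; the mesh shape, axis name, and tensor sizes are made up): a sharding annotation placed inside a function wrapped with assume_pure.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.assume_pure import assume_pure

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# One-dimensional mesh over all devices; the 'model' axis name is illustrative.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('model',))

@assume_pure
def sharded_mul(x):
  # mark_sharding inside an assume_pure region is what qihqi@'s #8989 enabled.
  xs.mark_sharding(x, mesh, ('model', None))
  return x * 2

x = torch.ones(num_devices, 4, device=torch_xla.device())
print(sharded_mul(x))
```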

tengyifei marked this pull request as ready for review on April 25, 2025 at 01:49
tengyifei requested a review from qihqi on April 25, 2025 at 01:49
qihqi (Collaborator) commented on Apr 25, 2025:

I think what's happening with ValueError: torch_xla device ID [1 2 3] not found in available JAX devices is:

torch_xla.devices() has 4 devices but jax.devices() only has 1. One possibility is that jax[cuda] was not installed, so jax.devices() returns a single device, and that is the CPU device.

Last time I tried to add the install, I hit a different error. I am OK with disabling the test for CUDA until later.
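
A quick, hedged way to check this hypothesis (assuming the standard jax and torch_xla.runtime APIs): compare the device count each runtime sees. If jax[cuda] is missing, jax.devices() falls back to a single CPU device while torch_xla still reports every GPU, which would trip the device-ID lookup.

```python
import jax
import torch_xla.runtime as xr

torch_xla_count = xr.global_runtime_device_count()
jax_devices = jax.devices()
if torch_xla_count != len(jax_devices):
  # Mismatch is consistent with JAX lacking a CUDA backend.
  print(f'Mismatch: torch_xla sees {torch_xla_count} devices, '
        f'but JAX sees {jax_devices}')
```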

tengyifei (Collaborator, Author) commented:

> I think what's happening with ValueError: torch_xla device ID [1 2 3] not found in available JAX devices is [...]

That is a great point!

This made me realize that if we land this PR as-is, not only will call_jax not work on GPUs, it also won't work under multi-slice TPUs. That in turn means shard_as won't work, and scan won't work either, blocking training of models such as Llama 3.1 405B.

I think I'll split out the call_jax part of this PR into a separate one, and unfortunately that one can't be landed unless we fix the PJRT client sharing between PyTorch/XLA and JAX.

tengyifei force-pushed the yifeit/call-jax-mesh branch from 8e682cd to d83ffca on April 25, 2025 at 18:42
tengyifei changed the title from "[call_jax] Bridge the torch_xla and JAX mesh and improve assume_pure" to "Improve assume_pure docs and tests" on Apr 25, 2025
tengyifei enabled auto-merge (squash) on April 25, 2025 at 19:39
Commit message:

> Now we can run a JAX SPMD function that accesses the ambient SPMD mesh from xb.call_jax.
>
> Fixes #8972.
>
> Also I beefed up the assume_pure tests and updated the docs to mention that mark_sharding is supported thanks to qihqi@'s #8989.
tengyifei force-pushed the yifeit/call-jax-mesh branch from 8b110a4 to debf063 on April 25, 2025 at 22:34
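
For context, a minimal sketch of the behavior the commit message above describes; per the discussion, this mesh bridging was split out of this PR. The axis name, shapes, and use of a bare PartitionSpec (which only resolves against an ambient mesh) are illustrative, not taken from the branch.

```python
import jax
from jax.sharding import PartitionSpec as P
import torch
import torch_xla
import torch_xla.core.xla_builder as xb

def jax_fn(x):
  # A bare PartitionSpec has no explicit mesh, so JAX must look up the
  # ambient SPMD mesh, which the branch bridges from torch_xla.
  x = jax.lax.with_sharding_constraint(x, P('model'))
  return x + 1

x = torch.ones(8, 8, device=torch_xla.device())
out = xb.call_jax(jax_fn, (x,))
```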
tengyifei (Collaborator, Author) commented:

Draft for registering the JAX contextual mesh in call_jax: #9043 (informational only)

tengyifei merged commit 7d681a9 into master on Apr 26, 2025. 23 of 24 checks passed.