Improve assume_pure docs and tests #9038
Conversation
I think what's happening is that torch_xla.devices() has 4 devices but jax.devices() only has 1. One possibility is that jax[cuda] was not installed, so jax.devices() returns one device, and that is the CPU device. Last time I tried to add the install, I hit a different error. I am OK with disabling the test for CUDA until later.
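For illustration, a minimal sketch of the mismatch described above, assuming the top-level `torch_xla.devices()` and `jax.devices()` APIs; exact outputs depend on the installed backends:

```python
# Compare the devices each runtime sees on the same host.
import jax
import torch_xla

print(torch_xla.devices())  # e.g. four CUDA devices on a 4-GPU machine
print(jax.devices())        # may be only [CpuDevice(id=0)] if jax[cuda] is missing
```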
That is a great point! This made me realize we shouldn't land this PR as-is. I think I'll split out the call_jax part of this PR into a separate one, and unfortunately that one can't be landed unless we fix the PJRT client sharing between PyTorch/XLA and JAX.
Force-pushed from 8e682cd to d83ffca
Force-pushed from 8b110a4 to debf063
Draft for registering the JAX contextual mesh in call_jax: #9043 (informational only)
Beefed up the assume_pure tests and updated the docs to mention that mark_sharding is supported, thanks to qihqi@'s #8989.
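For reference, a hedged sketch of what this enables, assuming the experimental module path below, a global SPMD mesh already set up by the caller, and an illustrative axis name `'x'` (the function name is made up, not from this PR):

```python
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.assume_pure import assume_pure

@assume_pure
def sharded_matmul(a, b):
    # Calling mark_sharding inside an assume_pure'd function is supported
    # thanks to #8989; the partition spec here assumes a mesh axis named 'x'.
    mesh = xs.get_global_mesh()
    xs.mark_sharding(a, mesh, ('x', None))
    return a @ b
```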
Also updated yapf in the dev image to match CI.