Error when following the JAX tutorial on "Introduction to parallel programming" #25541
Comments
Is this on CPU/TPU/GPU? But it looks like this is working as expected. The mesh throughout the program needs to be the same. You can't use a different mesh for inputs and for with_sharding_constraint.
This is on TPU. If you check the tutorial, it first creates a (2, 4) mesh and then applies with_sharding_constraint with an (8,) mesh. It seems that JAX allows this in some cases, even though the mesh needs to be consistent across the program; whether the mismatched meshes work appears to depend on the specific context and on how JAX handles the two meshes.
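For illustration, the pattern being described boils down to something like this sketch (assuming 8 devices; not the exact tutorial code):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

# 2D mesh used to shard the input.
mesh_2d = Mesh(mesh_utils.create_device_mesh((2, 4)), ('x', 'y'))
x = jax.device_put(np.arange(32.0).reshape(8, 4),
                   NamedSharding(mesh_2d, P('x', 'y')))

# Separate 1D mesh over the same devices, used only inside the constraint.
mesh_1d = Mesh(mesh_utils.create_device_mesh((8,)), ('i',))

@jax.jit
def f(x):
    # If mesh_1d orders the devices differently from mesh_2d, jit rejects this
    # with an "incompatible devices" style error; with an identical device
    # order it goes through.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh_1d, P('i')))

f(x)
```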
As long as the device order is the same, jit should not complain. So maybe the action item here is to change the tutorial to use only one mesh and not create another one inside jit?
Remind me why we don't wanna lift this restriction? Was it too expensive to rewrite the shardings against a single device assignment?
Yes, it's too expensive, and lowering usually should not care about devices (there are also a lot of considerations with respect to cache hits that would be broken). And there's no real use case for it.
Can we just do the rewrite if there are multiple meshes? Then it won't be a regression, since the current behavior is an error. I don't really understand the point about lowering not caring about devices... it's just a different sharding. Same with caching: I agree we have to be careful to cache correctly, but presumably we already cache on shardings?
What do you mean by this? Rewrite the tutorial? (If yes, that's what I suggested above too.)
I meant to address it being too expensive. We only need to do the expensive rewrite of shardings if there are multiple meshes. Like assume the first mesh you see is the "default" mesh.
I don't think that's a good idea. But what problem are you solving here? What's the use case? This never happens in practice. You only have 1 mesh.
It seems the use case would be any scenario where jax.lax.with_sharding_constraint() is useful, if such a scenario exists.
I don't think so. There are thousands of with_sharding_constraint calls everywhere and all of them have the same mesh. In fact, you should not even talk about devices when specifying a with_sharding_constraint. You should use an AbstractMesh instead.
Of course they all have the same mesh, because it's an error if they don't :P The fact that we wrote this in our tutorial means someone might naturally try to express shardings this way. A concrete use case would be switching from a mesh over hosts for data loading to a mesh over ICI physical axes for compute (e.g. to reshard input data over ICI). It's always possible to write everything in terms of one mesh, but it might be easier for users to use multiple meshes and let jax do the rewrite, instead of making users do the rewrite themselves.

I didn't know about AbstractMesh (#23022). This is nice, but it adds extra boilerplate if you're not dealing with multiple meshes (of the same shape + names), so I don't think it should be the default choice. And it doesn't preclude actually wanting to use different meshes. BTW, we'll likely want to support tracing over different meshes if we ever want to support mjit-style tracing of MPMD programs, right?
it's a performance footgun too. And there's no real use for switching a mesh in the middle of a computation AFAIK.
If I understand what you mean, this already works without any mesh changes. You just need to change the PartitionSpec for it (not the underlying mesh). This pattern is used all over, I think.
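Something like this sketch, assuming one (2, 4) mesh: the reshard is just a different PartitionSpec on the same mesh.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((2, 4)), ('x', 'y'))

# Data arrives sharded along 'x' only, e.g. straight from data loading.
x = jax.device_put(np.arange(32.0).reshape(8, 4), NamedSharding(mesh, P('x')))

@jax.jit
def f(x):
    # Reshard for compute by changing the PartitionSpec; the mesh stays fixed.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x', 'y')))
    return x * 2

f(x)
```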
Just do it outside
I think it should be. You should never talk about devices inside jit (just like we have abstract values, we should have an AbstractMesh). It leads to all kinds of problems that the PR description explains. There are more changes that sharding in types is going to bring in this area, but the point is to never ever talk about devices inside jit.
Well,
#25553 fixes the tutorial BTW
Can you do this with multi-controller jax?
Yup! That's the best part! It's actually being used by some cloud customers! Try it!
Description
I was following the JAX tutorial on 'Introduction to Parallel Programming' and encountered an error. The tutorial is designed for 8 accelerators, but I modified the settings for 4 accelerators. The code is:
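(Roughly — a sketch assuming a (2, 2) mesh over the 4 devices plus a second (4,) mesh, following the tutorial's pattern, rather than my exact code.)

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

# The tutorial's (2, 4) mesh scaled down to 4 devices.
mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ('x', 'y'))
arr = jax.device_put(np.arange(32.0).reshape(4, 8),
                     NamedSharding(mesh, P('x', 'y')))

# A constraint built from a second, 1D mesh, as in the tutorial.
mesh_1d = Mesh(mesh_utils.create_device_mesh((4,)), ('i',))

@jax.jit
def f(x):
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh_1d, P('i')))

f(arr)
```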
I received the following error message:
System info (python version, jaxlib version, accelerator, etc.)