
Error when following the JAX tutorial on "Introduction to parallel programming" #25541

Stella-S-Yan opened this issue Dec 17, 2024 · 16 comments

@Stella-S-Yan (Collaborator)

Description

I was following the JAX tutorial on 'Introduction to Parallel Programming' and encountered an error. The tutorial is written for 8 accelerators, but I adapted the settings to 4 accelerators. The code is:

import jax
import jax.numpy as jnp

arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()

from jax.sharding import PartitionSpec as P

# Shard the input over a 2x2 mesh of the 4 devices.
mesh = jax.make_mesh((2, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)

arr_sharded = jax.device_put(arr, sharding)

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # A second, 1-D mesh is created inside jit for the output constraint.
  mesh = jax.make_mesh((4,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)

I received the following error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[102], line 35
     32   sharding = jax.sharding.NamedSharding(mesh, P('x'))
     33   return jax.lax.with_sharding_constraint(out, sharding)
---> 35 result = f_contract_2(arr_sharded)
     36 jax.debug.visualize_array_sharding(result)
     37 print(result)

    [... skipping hidden 2 frame]

File ~/miniconda3/envs/learn_jax/lib/python3.11/site-packages/jax/_src/pjit.py:206, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    203   fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
    204   msg = _device_assignment_mismatch_error(
    205       fun_name, fails, args_flat, api_name, p.arg_names)
--> 206   raise ValueError(msg) from None
    207 except xla.InvalidInputException as e:
    208   arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names

ValueError: Received incompatible devices for jitted computation. Got argument x of f_contract_2 with shape float32[4,8] and device ids [0, 1, 2, 3] on platform TPU and sharding_constraint inside jit with device ids [0, 2, 1, 3] on platform TPU at /tmp/ipykernel_91499/2091669677.py:33:9 (f_contract_2)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.37
jaxlib: 0.4.36
numpy:  2.2.0
python: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]
device info: TPU v4-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-04c9a2e6-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')
@yashk2810 (Collaborator)

Is this on CPU/TPU/GPU?

But it looks like this is working as expected. The mesh throughout the program needs to be the same; you can't use a different mesh for the inputs and for with_sharding_constraint.
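
For illustration, a minimal single-mesh version of the snippet above (a sketch, assuming the same 4-device setup): the (2, 2) mesh created outside jit is reused inside, and only the PartitionSpec changes.

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding

arr = jnp.arange(32.0).reshape(4, 8)

# One mesh for the whole program.
mesh = jax.make_mesh((2, 2), ('x', 'y'))
arr_sharded = jax.device_put(arr, NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # Same mesh as the input; the (8,) result is sharded over both mesh axes.
  return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, P(('x', 'y'))))

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)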

@Stella-S-Yan (Collaborator, Author)

This is on TPU. If you check the tutorial, it first creates a (2, 4) mesh and then applies with_sharding_constraint with an (8,) mesh. So JAX does seem to allow this in some cases, even though the mesh needs to be consistent across the program; whether the mismatch in mesh shapes works presumably depends on the resulting device order and how JAX handles the mesh change.

@yashk2810 (Collaborator)

As long as the device order is the same, jit should not complain.

So maybe the action item here is to change the tutorial to use only one mesh and not create another one inside jit?

@skye (Member) commented Dec 17, 2024

Remind me why we don't wanna lift this restriction? Was it too expensive to rewrite the shardings against a single device assignment?

@yashk2810 (Collaborator)

Yes, it's too expensive, and lowering usually should not care about devices (there are also a lot of considerations around cache hits that would be broken).

And there's no real use case for it.

@skye (Member) commented Dec 17, 2024

Can we just do the rewrite if there are multiple meshes? Then it won't be a regression, since the current behavior is an error.

I don't really understand the point about lowering not caring about devices... it's just a different sharding. Same with caching: I agree we have to be careful to cache correctly, but presumably we already cache on shardings?

@yashk2810 (Collaborator)

Can we just do the rewrite if there are multiple meshes?

What do you mean by this? Rewrite the tutorial? (If yes, that's what I suggested above too.)

@skye (Member) commented Dec 17, 2024

I meant to address it being too expensive. We only need to do the expensive rewrite of shardings if there are multiple meshes. Like assume the first mesh you see is the "default" mesh.

@yashk2810 (Collaborator)

Like assume the first mesh you see is the "default" mesh

I don't think that's a good idea.

But what problem are you solving here? What's the use case? This never happens in practice. You only have 1 mesh.

@Stella-S-Yan (Collaborator, Author)

It seems the use case would be any scenario where jax.lax.with_sharding_constraint() is useful, if such a scenario exists.

@yashk2810 (Collaborator)

I don't think so. There are thousands of with_sharding_constraint calls everywhere, and all of them use the same mesh.

In fact, you should not even talk about devices when specifying a with_sharding_constraint. You should use AbstractMesh instead. Our docs don't mention that, but it's the recommended path.
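
For illustration only, a rough sketch of what that could look like. This assumes jax.sharding.AbstractMesh and the Mesh.abstract_mesh property available in recent JAX releases, and that NamedSharding accepts an abstract mesh inside jit; exact names and signatures may differ by version.

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding

mesh = jax.make_mesh((2, 2), ('x', 'y'))
arr = jax.device_put(jnp.arange(32.0).reshape(4, 8),
                     NamedSharding(mesh, P('x', 'y')))

# Assumption: mesh.abstract_mesh carries only the axis names and sizes, no
# concrete devices, so nothing inside jit refers to a device assignment.
abstract_mesh = mesh.abstract_mesh

@jax.jit
def f(x):
  out = x.sum(axis=0)
  return jax.lax.with_sharding_constraint(out, NamedSharding(abstract_mesh, P('y')))

print(f(arr))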

@skye (Member) commented Dec 17, 2024

There are thousands of with_sharding_constraint everywhere and all of them have the same mesh.

Of course they all have the same mesh, because it's an error if they don't :P

The fact that we wrote this in our tutorial means someone might naturally try to express shardings this way. A concrete use case would be switching from a mesh over hosts for data loading to a mesh over ICI physical axes for compute (e.g. to reshard input data over ICI). It's always possible to write everything in terms of one mesh, but it might be easier for users to use multiple meshes and let jax do the rewrite, instead of making users do the rewrite themselves.

I didn't know about AbstractMesh (#23022). This is nice, but it adds extra boilerplate if you're not dealing with multiple meshes (of the same shape + names), so I don't think it should be the default choice. And it doesn't address actually wanting to use different meshes.

BTW, we'll likely want to support tracing over different meshes if we ever want to support mjit-style tracing of MPMD programs, right?

@yashk2810 (Collaborator)

Of course they all have the same mesh,

It's a performance footgun too. And there's no real use for switching a mesh in the middle of a computation, AFAIK.

A concrete use case would be switching between a mesh over hosts for data loading to a mesh over ICI physical axes for compute (e.g. to reshard input data over ICI)

If I understand what you mean, this already works without any mesh changes. You just need to change the PartitionSpec (not the underlying mesh). This pattern is used all over, I think.

but it might be easier for users to use multiple meshes and let jax do the rewrite

Just do it outside jit with a device_put. The cost is the same, since a collective permute gets inserted in both cases. I don't buy the extra-convenience argument given the amount of complexity JAX would need to support it; it's better for users to do the reshard explicitly.
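
A rough sketch of that pattern (hypothetical names; mesh_hosts, mesh_ici, and f are made up here): the array is resharded eagerly with device_put before the jitted call, so only one mesh is ever seen inside jit.

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding

# Hypothetical setup: a 2x2 mesh used when loading data, and a flat
# 4-device mesh used for compute, both over the same 4 devices.
mesh_hosts = jax.make_mesh((2, 2), ('x', 'y'))
mesh_ici = jax.make_mesh((4,), ('i',))

x = jax.device_put(jnp.arange(32.0).reshape(4, 8),
                   NamedSharding(mesh_hosts, P('x', 'y')))

# Reshard onto the compute mesh outside jit; per the comment above, this
# costs a collective permute either way.
x = jax.device_put(x, NamedSharding(mesh_ici, P('i')))

@jax.jit
def f(x):
  # Only mesh_ici is referenced inside jit.
  out = x.sum(axis=1)
  return jax.lax.with_sharding_constraint(out, NamedSharding(mesh_ici, P('i')))

print(f(x))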

so I don't think it should be the default choice

I think it should be. You should never talk about devices inside jit (just as we have abstract values, we should have an AbstractMesh). It leads to all kinds of problems, which the PR description explains. There are more changes that sharding-in-types is going to bring in this area, but the point is: never talk about devices inside jit.

we'll likely want to support tracing over different meshes if we ever want to support mjit-style tracing MPMD programs right?

Well, mjit already supports that. But within each jit inside an mjit, the mesh stays the same. The mesh can change outside of each jit, and that's perfectly fine.

@yashk2810 (Collaborator)

#25553 fixes the tutorial, BTW.

@skye (Member) commented Dec 18, 2024

Just do it outside jit with a device_put.

Can you do this with multi-controller jax?

@yashk2810 (Collaborator) commented Dec 18, 2024

Can you do this with multi-controller jax?

Yup! That's the best part! It's actually being used by some cloud customers! Try it!
