Skip to content

Commit

Permalink
narrow range where we use jit for sharding
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jul 15, 2024
1 parent 40cd6fe commit edf6eb7
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/haliax/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ def _do_device_put(named):

sharding = infer_resource_partitions(named, mapping, mesh=mesh, preserve_existing_shardings=False)
assert isinstance(sharding, NamedSharding)
in_sharding = named.array.sharding
if is_in_jit():
return with_sharding_constraint(named, sharding)
# as a special case, SingleDeviceShardings are routed through jit
elif isinstance(named.array.sharding, SingleDeviceSharding):
elif isinstance(in_sharding, SingleDeviceSharding) and in_sharding._device in sharding.device_set:
# TODO(dlwh): this should be unnecessary in JAX soon. Check after 2024-08-01
sharded_array = jax.jit(lambda x: x, out_shardings=sharding)(named)
return sharded_array
Expand Down

0 comments on commit edf6eb7

Please sign in to comment.