Skip to content

Commit

Permalink
hnadjklhnad
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jun 27, 2024
1 parent 46abeff commit 40cd6fe
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/haliax/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,26 +128,26 @@ def shard(x: T, mapping: Optional[ResourceMapping] = None, mesh: Optional[Mesh]
warnings.warn("Sharding constraints are not supported in jit on metal", RuntimeWarning)
return x

def _do_device_put(x):
if not isinstance(x, NamedArray):
return x
def _do_device_put(named):
if not isinstance(named, NamedArray):
return named

if not is_jax_array_like(x.array):
if not is_jax_array_like(named.array):
# this happens when we filter out params for things like lora.
# could use eqx.partition to avoid this, but eh
return x
return named

sharding = infer_resource_partitions(x, mapping, mesh=mesh, preserve_existing_shardings=False)
sharding = infer_resource_partitions(named, mapping, mesh=mesh, preserve_existing_shardings=False)
assert isinstance(sharding, NamedSharding)
if is_in_jit():
return with_sharding_constraint(x, sharding)
return with_sharding_constraint(named, sharding)
# as a special case, SingleDeviceShardings are routed through jit
elif isinstance(x.array.sharding, SingleDeviceSharding):
elif isinstance(named.array.sharding, SingleDeviceSharding):
# TODO(dlwh): this should be unnecessary in JAX soon. Check after 2024-08-01
sharded_array = jax.jit(lambda x: x, out_shardings=sharding)(x)
return NamedArray(sharded_array, x.axes)
sharded_array = jax.jit(lambda x: x, out_shardings=sharding)(named)
return sharded_array
else:
ret = jax.device_put(x, sharding)
ret = jax.device_put(named, sharding)
return ret

return htu.tree_map(_do_device_put, x)
Expand Down

0 comments on commit 40cd6fe

Please sign in to comment.