diff --git a/src/haliax/partitioning.py b/src/haliax/partitioning.py index 48e8b49..5171eae 100644 --- a/src/haliax/partitioning.py +++ b/src/haliax/partitioning.py @@ -139,10 +139,11 @@ def _do_device_put(named): sharding = infer_resource_partitions(named, mapping, mesh=mesh, preserve_existing_shardings=False) assert isinstance(sharding, NamedSharding) + in_sharding = named.array.sharding if is_in_jit(): return with_sharding_constraint(named, sharding) # as a special case, SingleDeviceShardings are routed through jit - elif isinstance(named.array.sharding, SingleDeviceSharding): + elif isinstance(in_sharding, SingleDeviceSharding) and in_sharding._device in sharding.device_set: # TODO(dlwh): this should be unnecessary in JAX soon. Check after 2024-08-01 sharded_array = jax.jit(lambda x: x, out_shardings=sharding)(named) return sharded_array