Skip to content

Commit

Permalink
No default in_resources (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Jun 12, 2024
1 parent 072fc9e commit ac6dcae
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
2 changes: 0 additions & 2 deletions src/haliax/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,6 @@ def _call(self, is_lower, *args, **kwargs):
my_pjit_args = dict(**self._pjit_args)

if in_axis_resources is not None or axis_resources is not None:
if in_axis_resources is None:
in_axis_resources = axis_resources
in_resources = infer_resource_partitions(
(dynamic_donated, dynamic_reserved),
in_axis_resources,
Expand Down
26 changes: 24 additions & 2 deletions tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self, in_array: NamedArray):

devices = jax.devices()
with Mesh(np.array(devices).reshape(-1, 1), (ResourceAxis.DATA, ResourceAxis.MODEL)):
mod = named_jit(ModWithArgs)(hax.ones((Dim1, Dim2)))
mod = named_jit(ModWithArgs)(hax.shard(hax.ones((Dim1, Dim2))))
assert isinstance(mod, ModWithArgs)
assert mod.array.array.shape == (Dim1.size, Dim2.size)
assert mod.array2.array.shape == (Dim3.size,)
Expand Down Expand Up @@ -173,7 +173,7 @@ def assert_eq(x, y):

jax.debug.inspect_array_sharding(arr.array, callback=lambda x: assert_eq(x, expected))

@named_jit(in_axis_resources={}, out_axis_resources=resource_map)
@named_jit(out_axis_resources=resource_map)
def do_shard(x, y):
x = hax.shard(x, resource_map)
assert_inside_pjit(x, NamedSharding(mesh, PartitionSpec(None, ResourceAxis.DATA)))
Expand Down Expand Up @@ -293,3 +293,25 @@ def test_cross_device_sharding():
z_devices = z.array.devices()

assert set(d.platform for d in x_devices) == set(d.platform for d in z_devices)


def test_named_jit_no_in_axis_resources():
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), (ResourceAxis.DATA, ResourceAxis.MODEL))
with axis_mapping(resource_map), mesh:

class MyModule(eqx.Module):
array: NamedArray

def __init__(self):
self.array = hax.ones((Dim1, Dim2))

data = hax.ones((Dim1, Dim2))
data = hax.shard(data, {})

@named_jit(axis_resources=resource_map)
def fn(data):
mod = MyModule()
return mod.array

r = fn(data)
assert r.array.sharding.device_set == set(jax.devices())

0 comments on commit ac6dcae

Please sign in to comment.