Skip to content

Commit

Permalink
add errors when scan contract isn't respected, fix logic bug in array slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Sep 3, 2024
1 parent eaf6a38 commit bf807bb
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/haliax/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def wrapper(*args, **kwargs):


def is_jax_array_like(x):
# Duck-typed check: true when `x` exposes both a `shape` and a `dtype` attribute.
# NOTE(review): this is a diff view with the +/- markers stripped; the two return
# lines below are presumably the removed line followed by the added line (the file
# header reports 1 addition & 1 deletion). They differ only in the trailing comment,
# which leaves the NamedArray exclusion commented out rather than enabled.
return hasattr(x, "shape") and hasattr(x, "dtype")
return hasattr(x, "shape") and hasattr(x, "dtype") # and not isinstance(x, haliax.NamedArray)


# adapted from JAX but exposed so I can use it
Expand Down
16 changes: 13 additions & 3 deletions src/haliax/nn/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,14 @@ def scan(self, init: T, *extra_args, **extra_kwargs):
(block_args, block_kwargs) = haliax.tree_util.tree_map(
functools.partial(BlockSeq._slice_out, self.Block, i), (extra_args, extra_kwargs)
)
carry, extra = block(carry, *block_args, **block_kwargs)
block_result = block(carry, *block_args, **block_kwargs)
if not isinstance(block_result, (tuple, list)) or len(block_result) != 2:
raise ValueError(
f"BlockSeq.scan expects the block to return a pair of (carry, extra), got {block_result}"
)

carry, extra = block_result

out.append(extra)

# TODO: do we want to stack the outputs?
Expand All @@ -124,8 +131,11 @@ def unstacked(self) -> Sequence[M]:

@staticmethod
def _slice_out(Block, i, x):
if haliax.is_named_array(x) and haliax.selects_axis(x.axes, Block):
return x[Block, i]
if haliax.is_named_array(x):
if haliax.selects_axis(x.axes, Block):
return x[Block, i]
else:
return x
elif haliax.jax_utils.is_jax_array_like(x):
return x[i]
else:
Expand Down
40 changes: 39 additions & 1 deletion tests/test_scan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import equinox as eqx
import jax
import pytest

import haliax as hax
from haliax.nn.scan import BlockSeq, Stacked
Expand Down Expand Up @@ -66,9 +67,14 @@ def init(named, array, static):
x = hax.random.uniform(jax.random.PRNGKey(1), (E,))
y = m.fold(x, key=jax.random.split(jax.random.PRNGKey(2), Block.size))
y_seq = m_seq.fold(x, key=jax.random.split(jax.random.PRNGKey(2), Block.size))

assert hax.all(hax.isclose(y, y_seq, atol=1e-5))

with pytest.raises(ValueError):
m.scan(x, key=jax.random.split(jax.random.PRNGKey(2), Block.size))

with pytest.raises(ValueError):
m_seq.scan(x, key=jax.random.split(jax.random.PRNGKey(2), Block.size))


def test_using_scan():
class Module(eqx.Module):
Expand All @@ -95,3 +101,35 @@ def init(named, array, static):

assert y.axes == (E,)
assert intermediates.axes == (Block, E)


def test_scan_with_aux_named_args():
    """Stacked.scan and BlockSeq.scan must agree when an extra named array is passed positionally.

    The extra argument ``y`` has no ``block`` axis; both implementations are
    given the same ``y`` and the same per-block PRNG keys, and their final
    carries and per-block outputs are compared.
    """

    class Module(eqx.Module):
        named: hax.NamedArray
        array: jax.Array
        static: int = eqx.static_field()

        def __call__(self, x, y, *, key):
            # scan expects a (carry, extra) pair back from every block.
            noise = hax.random.normal(key, x.axes)
            next_carry = x + self.array + self.static + noise
            per_block_out = x * 2 + y
            return next_carry, per_block_out

        @staticmethod
        def init(named, array, static):
            return Module(named=named, array=array, static=static)

    Block = hax.Axis("block", 4)
    E = hax.Axis("E", 10)

    named_init = hax.random.uniform(jax.random.PRNGKey(0), (Block, E))
    aux_y = hax.random.uniform(jax.random.PRNGKey(1), (E,))

    # Identical constructor arguments for both container flavours.
    module_kwargs = dict(named=named_init, array=jax.numpy.ones(Block.size), static=1)
    stacked = Stacked.init(Block, Module)(**module_kwargs)
    block_seq = BlockSeq.init(Block, Module)(**module_kwargs)

    x = hax.random.uniform(jax.random.PRNGKey(1), (E,))
    # One key per block, shared by both runs so the random draws match.
    block_keys = jax.random.split(jax.random.PRNGKey(2), Block.size)

    carry_stacked, outs_stacked = stacked.scan(x, aux_y, key=block_keys)
    carry_seq, outs_seq = block_seq.scan(x, aux_y, key=block_keys)
    assert hax.all(hax.isclose(carry_stacked, carry_seq, atol=1e-5))

    # The BlockSeq outputs are stacked along Block here so they can be compared
    # against the Stacked outputs.
    outs_seq = hax.stack(Block, outs_seq)

    assert hax.all(hax.isclose(outs_stacked, outs_seq, atol=1e-5))

0 comments on commit bf807bb

Please sign in to comment.