Skip to content

Commit

Permalink
Merge branch 'main' into intersect-axes
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Sep 6, 2024
2 parents 00585c6 + 019fa2d commit 16ed1f1
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 463 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ Head = hax.Axis("head", 8) # number of attention heads
Key = hax.Axis("key", 64) # key size
Embed = hax.Axis("embed", 512) # embedding size

# alternatively:
#Pos, KPos, Head, Key, Embed = hax.make_axes(pos=1024, key_pos=1024, head=8, key=64, embed=512)


def attention_scores(Key, KPos, query, key, mask):
# how similar is each query to each key
Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Occasionally, an axis size can be inferred in some circumstances but not others.

### Axis Manipulation

::: haliax.make_axes
::: haliax.axis.axis_name
::: haliax.axis.concat_axes
::: haliax.axis.union_axes
Expand Down
4 changes: 4 additions & 0 deletions src/haliax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ds,
dslice,
eliminate_axes,
make_axes,
selects_axis,
)
from .core import (
Expand Down Expand Up @@ -893,6 +894,9 @@ def true_divide(x1: NamedOrNumeric, x2: NamedOrNumeric, /) -> NamedOrNumeric:
"AxisSpec",
"AxisSelection",
"AxisSelector",
"make_axes",
"axis_name",
"axis_size",
"NamedArray",
"broadcast_to",
"broadcast_axis",
Expand Down
11 changes: 11 additions & 0 deletions src/haliax/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ def __str__(self):
return f"{self.name}({self.size})"


def make_axes(**kwargs: int) -> Tuple[Axis, ...]:
    """
    Convenience function for creating a tuple of Axis objects.

    Each keyword argument produces one Axis: the keyword is used as the axis
    name and the value as the axis size. The axes are returned in the same
    order as the keyword arguments were passed.

    Example:
    ```
    X, Y = make_axes(X=10, Y=20)
    ```
    """
    return tuple(Axis(name, size) for name, size in kwargs.items())


AxisSelector = Union[Axis, str]
"""AxisSelector is a type that can be used to select a single axis from an array. str or Axis"""
AxisSelection = Union[AxisSelector, Sequence[AxisSelector]]
Expand Down
5 changes: 2 additions & 3 deletions src/haliax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,10 @@ def __getitem__(self, idx: SliceSpec) -> "NamedArray":
Supports indexing like:
>>> X = Axis("x", 10)
>>> Y = Axis("y", 20)
>>> X, Y = haliax.make_axes(x=10, y=20)
>>> arr = haliax.random.randint(jax.random.PRNGKey(0), (X, Y), 0, X.size)
# slice with ints or slices
>>> arr[{"x": 1, "y": slice(0,10,new_axis=2)}]
>>> arr[{"x": 1, "y": slice(0,10,2)}]
>>> Z = Axis("z", 3)
# so-called "advanced indexing" with NamedArrays.
>>> index_arr = NamedArray(np.array([1, 2, 3]), Z)
Expand Down
41 changes: 20 additions & 21 deletions src/haliax/nn/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def init(
depth: int,
activation: Callable = relu,
*,
out_first: bool = True,
use_bias: bool = True,
use_final_bias: bool = True,
key: PRNGKeyArray,
Expand All @@ -57,36 +58,34 @@ def init(

layers = []

kwargs: dict = {
"use_bias": use_bias,
"dot_general": dot_general,
"init_scale": init_scale,
"out_first": out_first,
}

last_kwargs: dict = {
"use_bias": use_final_bias,
"dot_general": dot_general,
"init_scale": init_scale,
"out_first": out_first,
}

if depth == 0:
# special case: no hidden layers
layers.append(
Linear.init(
Input, Output, use_bias=use_final_bias, key=keys[0], dot_general=dot_general, init_scale=init_scale
)
)
layers.append(Linear.init(Input, Output, key=keys[0], **last_kwargs))
else:
# first hidden layer
layers.append(
Linear.init(
Input, Width, use_bias=use_bias, key=keys[0], dot_general=dot_general, init_scale=init_scale
)
)
layers.append(Linear.init(Input, Width, key=keys[0], **kwargs))
# middle hidden layers
cur = Width
next = Width2
for i in range(1, depth):
layers.append(
Linear.init(
cur, next, use_bias=use_bias, key=keys[i], dot_general=dot_general, init_scale=init_scale
)
)
layers.append(Linear.init(cur, next, key=keys[i], **kwargs))
cur, next = next, cur
# final hidden layer
layers.append(
Linear.init(
cur, Output, use_bias=use_final_bias, key=keys[-1], dot_general=dot_general, init_scale=init_scale
)
)
# final layer
layers.append(Linear.init(cur, Output, key=keys[-1], **last_kwargs))

return MLP(
layers=tuple(layers),
Expand Down
2 changes: 1 addition & 1 deletion src/haliax/nn/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def unstacked(self) -> Sequence[M]:
"""

def unbatch_leaf(x):
if haliax.is_named_array(x):
if isinstance(x, haliax.core.NamedArray):
if haliax.selects_axis(x.axes, self.Block):
return haliax.unbind(x, self.Block)
else:
Expand Down
Loading

0 comments on commit 16ed1f1

Please sign in to comment.