
Commit d3209f0
change default initialization of linear and embedding
dlwh committed Jun 7, 2024
1 parent 9228e8b commit d3209f0
Showing 5 changed files with 41 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/haliax/__init__.py
@@ -26,6 +26,7 @@
     AxisSelector,
     AxisSpec,
     axis_name,
+    axis_size,
     concat_axes,
     dblock,
     ds,
13 changes: 13 additions & 0 deletions src/haliax/axis.py
@@ -1,5 +1,6 @@
 import typing
 from dataclasses import dataclass
+from math import prod
 from types import EllipsisType
 from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union, overload

@@ -354,6 +355,17 @@ def _ax_name(ax: AxisSelector) -> str:
     return tuple(_ax_name(x) for x in ax)


+def axis_size(ax: AxisSpec) -> int:
+    """
+    Returns the size of the axis or the product of the sizes of the axes in the axis spec
+    """
+
+    if isinstance(ax, Axis):
+        return ax.size
+    else:
+        return prod(axis.size for axis in ensure_tuple(ax))  # type: ignore
+
+
 class dslice(eqx.Module):
     """
     Dynamic slice, comprising a (start, length) pair. Also aliased as ds.
@@ -524,6 +536,7 @@ def replace_missing_with_ellipsis(ax1: AxisSelection, ax2: AxisSelection) -> Par
     "PartialShapeDict",
     "ShapeDict",
     "axis_name",
+    "axis_size",
     "concat_axes",
     "union_axes",
     "axis_spec_to_shape_dict",
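
For illustration only (not part of the commit): a minimal usage sketch of the new axis_size helper, assuming the usual haliax import style. The axis names and sizes below are invented for the example.

import haliax as hax
from haliax import Axis

Height = Axis("height", 4)
Width = Axis("width", 8)

hax.axis_size(Height)           # 4: the size of a single Axis
hax.axis_size((Height, Width))  # 32: the product of the sizes in the axis spec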
8 changes: 6 additions & 2 deletions src/haliax/core.py
@@ -441,7 +441,11 @@ def dot(

 @typing.overload
 def dot(
-    self, *args, axis: Optional[AxisSelection], precision: PrecisionLike = None, dot_general=jax.lax.dot_general
+    self,
+    *args: "NamedArray",
+    axis: Optional[AxisSelection],
+    precision: PrecisionLike = None,
+    dot_general=jax.lax.dot_general,
 ) -> "NamedArray":
     ...

@@ -1143,7 +1147,7 @@ def flatten_axes(array: NamedArray, old_axes: AxisSelection, new_axis: AxisSelec
     """
     old_axes = ensure_tuple(old_axes)
     old_axes = array.resolve_axis(old_axes)
-    total_axis_size = prod(array.axis_size(ax) for ax in old_axes)
+    total_axis_size = haliax.axis_size(old_axes)

     if isinstance(new_axis, Axis):
         if new_axis.size != total_axis_size:
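
The flatten_axes change swaps a hand-rolled prod(...) loop for the new helper; behavior should be unchanged. A hedged sketch of the call site it affects, with hypothetical axis names, assuming flatten_axes is reachable from the top-level haliax namespace:

import jax.random as jrandom
import haliax as hax
from haliax import Axis

Batch, H, W = Axis("batch", 2), Axis("h", 3), Axis("w", 4)
x = hax.random.normal(jrandom.PRNGKey(0), (Batch, H, W))

# ("h", "w") resolves to axes whose total size is hax.axis_size((H, W)) == 12
flat = hax.flatten_axes(x, ("h", "w"), Axis("hw", 12))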
11 changes: 9 additions & 2 deletions src/haliax/nn/embedding.py
@@ -1,4 +1,6 @@
 import dataclasses
+import math
+import warnings
 from typing import Optional

 import equinox as eqx
@@ -21,9 +23,14 @@ class Embedding(eqx.Module):
     Embed: AxisSpec = eqx.static_field()

     @staticmethod
-    def init(Vocab: Axis, Embed: AxisSpec, initializer_range: float = 0.02, *, key):
+    def init(Vocab: Axis, Embed: AxisSpec, *, init_std: float = 1, key, initializer_range: Optional[float] = None):
+        if initializer_range is not None:
+            warnings.warn("initializer_range is deprecated. Use init_std instead.", DeprecationWarning)
+            init_std = initializer_range
+
         all_axes = (Vocab,) + ensure_tuple(Embed)
-        weight = hax.random.normal(key, all_axes) * initializer_range
+        output_size = hax.axis_size(Embed)
+        weight = hax.random.truncated_normal(key, all_axes, -3, 3) * (init_std / math.sqrt(output_size))
         return Embedding(weight=weight, Vocab=Vocab, Embed=Embed)

     def __call__(self, input_ids, *, key: Optional[PRNGKeyArray] = None):
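
This is the substantive change for embeddings: weights are now drawn from a truncated normal (clipped at ±3) scaled by init_std / sqrt(axis_size(Embed)), instead of a plain normal scaled by a fixed 0.02. A hedged usage sketch; the axis sizes are invented for the example:

import jax.random as jrandom
import haliax as hax
from haliax import Axis

Vocab, Embed = Axis("vocab", 1000), Axis("embed", 64)

# New default: weight std ~= 1 / sqrt(64) = 0.125, drawn from truncated_normal on [-3, 3]
embed = hax.nn.Embedding.init(Vocab, Embed, key=jrandom.PRNGKey(0))

# The old keyword still works but emits a DeprecationWarning and overrides init_std
legacy = hax.nn.Embedding.init(Vocab, Embed, initializer_range=0.02, key=jrandom.PRNGKey(1))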
14 changes: 12 additions & 2 deletions src/haliax/nn/linear.py
@@ -1,3 +1,4 @@
+import math
 from typing import Callable, Optional

 import equinox as eqx
@@ -25,7 +26,14 @@ class Linear(eqx.Module):

     @staticmethod
     def init(
-        In: AxisSpec, Out: AxisSpec, *, key, use_bias=True, out_first: bool = False, dot_general=None
+        In: AxisSpec,
+        Out: AxisSpec,
+        *,
+        key,
+        use_bias=True,
+        out_first: bool = False,
+        dot_general=None,
+        init_scale: float = 1.0,
     ) -> "Linear":
         """
@@ -36,9 +44,11 @@ def init(
             use_bias: bool: Whether to use a bias term
             out_first: bool: Whether to put output axes first in the weight matrix. out_first is how PyTorch does it.
             dot_general: Callable: The dot_general function to use. Defaults to jax.lax.dot_general. For fp8 or int8
+            init_scale: float: The scale to use for initialization. We scale init by init_scale / sqrt(In.size)
         """
         joint_spec = hax.concat_axis_specs(Out, In) if out_first else hax.concat_axis_specs(In, Out)
-        weight = hax.random.normal(key, joint_spec) * 0.02
+        input_size = hax.axis_size(In)
+        weight = hax.random.truncated_normal(key, joint_spec, -3, 3) * (init_scale / math.sqrt(input_size))
         bias = hax.zeros(Out) if use_bias else None

         if dot_general is None:
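
Likewise for Linear: the fixed normal * 0.02 becomes a truncated normal with std init_scale / sqrt(axis_size(In)), i.e. LeCun-style fan-in scaling. A hedged usage sketch; the axis names, sizes, and init_scale value are invented for the example:

import jax.random as jrandom
import haliax as hax
from haliax import Axis

In, Out = Axis("in", 128), Axis("out", 256)

# Default init_scale=1.0 gives weight std ~= 1 / sqrt(128)
linear = hax.nn.Linear.init(In, Out, key=jrandom.PRNGKey(0))

# init_scale can shrink the init, e.g. for residual-branch output projections
proj = hax.nn.Linear.init(In, Out, key=jrandom.PRNGKey(1), init_scale=0.5)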
