Skip to content

Commit

Permalink
fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
dlwh committed Jun 7, 2024
1 parent d3209f0 commit 40dd5a1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/haliax/nn/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ class Embedding(eqx.Module):
Embed: AxisSpec = eqx.static_field()

@staticmethod
def init(Vocab: Axis, Embed: AxisSpec, *, init_std: float = 1, key, initializer_range: Optional[float] = None):
def init(Vocab: Axis, Embed: AxisSpec, *, init_scale: float = 1, key, initializer_range: Optional[float] = None):
if initializer_range is not None:
warnings.warn("initializer_range is deprecated. Use init_scale instead.", DeprecationWarning)
init_std = initializer_range
init_scale = initializer_range

all_axes = (Vocab,) + ensure_tuple(Embed)
output_size = hax.axis_size(Embed)
weight = hax.random.truncated_normal(key, all_axes, -3, 3) * (init_std / math.sqrt(output_size))
weight = hax.random.truncated_normal(key, all_axes, -3, 3) * (init_scale / math.sqrt(output_size))
return Embedding(weight=weight, Vocab=Vocab, Embed=Embed)

def __call__(self, input_ids, *, key: Optional[PRNGKeyArray] = None):
Expand Down
11 changes: 7 additions & 4 deletions tests/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax.random as jrandom
import jax.tree_util
import numpy as np
from chex import assert_trees_all_close

import haliax as hax
from haliax._src.fp8 import compute_scale
Expand All @@ -19,18 +20,20 @@
def test_fp8_is_reasonable():
In = hax.Axis("In", 8)
Out = hax.Axis("Out", 8)
linear = Linear.init(In, Out, key=jrandom.PRNGKey(0))
linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), init_scale=0.1)

fp8_linear = Linear.init(In, Out, key=jrandom.PRNGKey(0), dot_general=hax.quantization.Fp8DotGeneralOp.init())
fp8_linear = Linear.init(
In, Out, key=jrandom.PRNGKey(0), dot_general=hax.quantization.Fp8DotGeneralOp.init(), init_scale=0.1
)

input = hax.random.normal(jrandom.PRNGKey(0), In) * 10
input = hax.random.normal(jrandom.PRNGKey(3), In)
output = linear(input)
fp8_output = fp8_linear(input)

assert output.shape == fp8_output.shape
assert output.dtype == fp8_output.dtype

assert jnp.allclose(output.array, fp8_output.array, atol=1e-1, rtol=1e-1)
assert_trees_all_close(output.array, fp8_output.array, atol=1e-2, rtol=5e-2)


# https://github.com/google/flax/blob/6f2b08e024c2fd2f8cec42a6c82408cb35412319/tests/linen/linen_test.py#L1222
Expand Down
1 change: 0 additions & 1 deletion tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def __init__(self, in_array: NamedArray):
def test_infer_resource_partition_gda_bug():
devices = jax.devices()
with Mesh(np.array(devices).reshape(-1, 1), (ResourceAxis.DATA, ResourceAxis.MODEL)):
jax.config.update("jax_parallel_functions_output_gda", True)
try:

def foo():
Expand Down

0 comments on commit 40dd5a1

Please sign in to comment.