From 7660d676e4c54ff7ef4805999ce9e8f24579e1e0 Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Mon, 12 Aug 2024 22:05:10 +0200 Subject: [PATCH 1/8] added pyright ignore to test - aren't we using that anymore? --- .gitignore | 3 +++ equinox/nn/_embedding.py | 46 ++++++++++++++++++++++++++++------------ tests/test_nn.py | 16 ++++++++++---- 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 57e24bc2..eb0d0e08 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ examples/MNIST examples/multipart_serialised.eqx .python-version .DS_Store +.ruff_cache +.pytest_cache +.venv diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index 5fb358b0..d3febd49 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -126,18 +126,22 @@ def __call__(...): def process_heads( query_heads: Float[Array, "seq_length num_heads qk_size"], key_heads: Float[Array, "seq_length num_heads qk_size"], - value_heads: Float[Array, "seq_length num_heads vo_size"] + value_heads: Float[Array, "seq_length num_heads vo_size"], + index: Int[Array, ""] ) -> tuple[ Float[Array, "seq_length num_heads qk_size"], Float[Array, "seq_length num_heads qk_size"], Float[Array, "seq_length num_heads vo_size"] ]: - query_heads = jax.vmap(self.rope_embeddings, - in_axes=1, - out_axes=1)(query_heads) - key_heads = jax.vmap(self.rope_embeddings, - in_axes=1, - out_axes=1)(key_heads) + # index is the autoregressive index of the current token + rope_partial = functools.partial( + rope_embeddings, + offset=index + ) + query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1) + (query_heads) + key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1) + (key_heads) return query_heads, key_heads, value_heads @@ -161,6 +165,7 @@ def process_heads( """ embedding_size: int = field(static=True) + max_seq_length: int = field(static=True) theta: float = field(static=True, default=10_000.0) def __check_init__(self): @@ -168,6 +173,8 @@ def __check_init__(self): raise ValueError("`embedding_size` must not be negative.") if (self.embedding_size % 2) != 0: raise ValueError("`embedding_size` must be even.") + if self.max_seq_length < 0: + raise ValueError("`max_seq_length` must not be negative.") @staticmethod def rotate_half(x: Float[Array, "seq_length embedding_size"]): @@ -194,12 +201,14 @@ def precompute_freqs_cis( def __call__( self, x: Float[Array, "seq_length embedding_size"], + offset: Int[Array, ""] = jnp.array(0), *, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "seq_length embedding_size"]: """**Arguments:** - `x`: A JAX array of shape `(seq_length, embedding_size)`. + - `offset`: The offset to apply to the positional encoding. - `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) 
@@ -215,28 +224,36 @@ def __call__( f"x.shape[-1] must match self.embedding_size, " f"but {x.shape[-1]} != {self.embedding_size}" ) + if seq_len > self.max_seq_length: + raise ValueError( + f"seq_len must be less than or equal to self.max_seq_length, " + f"but {seq_len} > {self.max_seq_length}" + ) with jax.ensure_compile_time_eval(): if embedding_size in internal_rope_embedding_cache: freqs_cis = internal_rope_embedding_cache[embedding_size] freqs_cis_seq_len, _ = freqs_cis.shape - if seq_len > freqs_cis_seq_len: + if self.max_seq_length > freqs_cis_seq_len: freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, self.max_seq_length, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis else: - freqs_cis = freqs_cis[:seq_len] + freqs_cis = freqs_cis[: self.max_seq_length] else: freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, self.max_seq_length, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis + freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len) + freqs_real = jnp.tile(freqs_cis.real, (1, 2)) freqs_imag = jnp.tile(freqs_cis.imag, (1, 2)) rotate_x = self.rotate_half(x) + x_rope = (x * freqs_real) + (rotate_x * freqs_imag) return x_rope @@ -244,8 +261,11 @@ def __call__( RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:** - `embedding_size`: Size of the token embeddings. Must be non-negative and even. -- `theta`: The base frequency for the sinusoidal functions. It defines the rate - of oscillation for the sine and cosine waves that encode positional information +- `theta`: The base frequency for the sinusoidal functions. It defines the rate + of oscillation for the sine and cosine waves that encode positional information into the embeddings. The larger the theta value, the slower the oscillations and vice versa. Defaults to 10_000.0 +- `max_seq_length`: The maximum sequence length for which to precompute the + positional encodings. This is used to determine the size of the precomputed + positional encodings. 
""" diff --git a/tests/test_nn.py b/tests/test_nn.py index 157b799c..493f032b 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -238,8 +238,8 @@ def test_mlp_learnt_activation(): key=jrandom.PRNGKey(5678), ) x = jnp.array([0.5, 0.7]) - assert mlp.activation.negative_slope.shape == (2, 8) - assert mlp.final_activation.negative_slope.shape == (5,) + assert mlp.activation.negative_slope.shape == (2, 8) # pyright: ignore + assert mlp.final_activation.negative_slope.shape == (5,) # pyright: ignore @eqx.filter_jit @eqx.filter_grad @@ -1352,13 +1352,16 @@ def test_prelu(getkey): def test_rope_embeddings_shapes(getkey): embedding_size = 32 - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) n_heads = 4 seq_length = 8 query_size = 32 key_size = 32 + rope_embeddings = eqx.nn.RotaryPositionalEmbedding( + embedding_size, max_seq_length=seq_length + ) + query_heads = jax.random.normal( key=getkey(), shape=(seq_length, n_heads, query_size) ) @@ -1435,7 +1438,12 @@ def test_rope_embeddings_values(): seq_length, embedding_size ) - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) + rope_embeddings = eqx.nn.RotaryPositionalEmbedding( + embedding_size, max_seq_length=seq_length + ) + res = rope_embeddings(x) + + assert jnp.allclose(res, expected_values, atol=1e-6) res = rope_embeddings(x) assert jnp.allclose(res, expected_values, atol=1e-6) From 285cce7f7964a8769acef8d771db6250345b7c1b Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Sat, 17 Aug 2024 14:02:31 +0200 Subject: [PATCH 2/8] added better docstring, made offset mandatory --- equinox/nn/_embedding.py | 75 +++++++++++++++++++++++++++++++--------- tests/test_nn.py | 3 ++ 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index d3febd49..a02fa075 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -118,35 +118,76 @@ class RotaryPositionalEmbedding(Module, strict=True): ```python class TransformerBlock(eqx.Module): rope_embeddings: RotaryPositionalEmbedding + mha_attention: MultiheadAttention - def __init__(...): - self.rope_embeddings = RotaryPositionalEmbedding(...) + def __init__(self, embedding_size, max_seq_length, num_heads, query_size): + self.rope_embeddings = RotaryPositionalEmbedding( + embedding_size, max_seq_length + ) + self.mha_attention = MultiheadAttention( + num_heads=num_heads, query_size=query_size, key=jax.random.key(0) + ) - def __call__(...): + def __call__(self, query, key_, value, index): def process_heads( query_heads: Float[Array, "seq_length num_heads qk_size"], key_heads: Float[Array, "seq_length num_heads qk_size"], value_heads: Float[Array, "seq_length num_heads vo_size"], - index: Int[Array, ""] + index: Int[Array, ""], ) -> tuple[ Float[Array, "seq_length num_heads qk_size"], Float[Array, "seq_length num_heads qk_size"], - Float[Array, "seq_length num_heads vo_size"] + Float[Array, "seq_length num_heads vo_size"], ]: # index is the autoregressive index of the current token - rope_partial = functools.partial( - rope_embeddings, - offset=index - ) - query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1) - (query_heads) - key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1) - (key_heads) + rope_p = functools.partial(self.rope_embeddings, offset=index) + query_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(query_heads) + key_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(key_heads) return query_heads, key_heads, value_heads - x = self.mha_attention(... process_heads=process_heads) - ... 
+ x = self.mha_attention( + query=query, + key_=key_, + value=value, + process_heads=functools.partial(process_heads, index=index), + ) + + return x + + embedding_size = 32 + max_seq_length = 8 + seq_length = 4 + num_heads = 2 + query_size = 64 + + transformer_block = eqx.filter_jit( + TransformerBlock(embedding_size, max_seq_length, num_heads, query_size) + ) + + q = jnp.ones(shape=(seq_length, query_size)) + k = jnp.ones(shape=(seq_length, query_size)) + v = jnp.ones(shape=(seq_length, query_size)) + + out = transformer_block(q, k, v, jnp.array(0)) + out = transformer_block(q, k, v, jnp.array(1)) # no re-JITing + ``` + + If you're training a transformer, you likely don't want to use any offset. In + those cases, it can be helpful to use `functools.partial` like so: + ```python + embedding_size = 32 + max_seq_length = 8 + + rot_emb = RotaryPositionalEmbedding( + embedding_size=embedding_size, max_seq_length=max_seq_length + ) + rot_emb = eqx.filter_jit(rot_emb) + rot_emb_no_offset = functools.partial(rot_emb, offset=jnp.array(0)) + + x = jnp.ones(shape=(max_seq_length, embedding_size)) + + assert jnp.allclose(rot_emb(x, offset=jnp.array(0)), rot_emb_no_offset(x)) ``` ??? cite @@ -201,7 +242,7 @@ def precompute_freqs_cis( def __call__( self, x: Float[Array, "seq_length embedding_size"], - offset: Int[Array, ""] = jnp.array(0), + offset: Int[Array, ""], *, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "seq_length embedding_size"]: @@ -217,7 +258,7 @@ def __call__( A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional encoding applied to the input. """ - + print("JIT ROPE") seq_len, embedding_size = x.shape if embedding_size != self.embedding_size: raise ValueError( diff --git a/tests/test_nn.py b/tests/test_nn.py index 493f032b..e4d73c36 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,3 +1,4 @@ +import functools import warnings from typing import Union @@ -1361,6 +1362,7 @@ def test_rope_embeddings_shapes(getkey): rope_embeddings = eqx.nn.RotaryPositionalEmbedding( embedding_size, max_seq_length=seq_length ) + rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0)) query_heads = jax.random.normal( key=getkey(), shape=(seq_length, n_heads, query_size) @@ -1441,6 +1443,7 @@ def test_rope_embeddings_values(): rope_embeddings = eqx.nn.RotaryPositionalEmbedding( embedding_size, max_seq_length=seq_length ) + rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0)) res = rope_embeddings(x) assert jnp.allclose(res, expected_values, atol=1e-6) From 539eaa1e038a28006ae5c55c1dc316b4de5d8a62 Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Sun, 18 Aug 2024 20:52:28 +0200 Subject: [PATCH 3/8] remove spurious print statement --- equinox/nn/_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index a02fa075..3ef393cb 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -258,7 +258,7 @@ def __call__( A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional encoding applied to the input. 
""" - print("JIT ROPE") + seq_len, embedding_size = x.shape if embedding_size != self.embedding_size: raise ValueError( From cd28d7b26f218c6d424d8427a991cffd636d25eb Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Sun, 18 Aug 2024 21:07:00 +0200 Subject: [PATCH 4/8] shortened docs to essentials only --- equinox/nn/_embedding.py | 30 +++--------------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index 3ef393cb..c011d1e8 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -120,13 +120,11 @@ class TransformerBlock(eqx.Module): rope_embeddings: RotaryPositionalEmbedding mha_attention: MultiheadAttention - def __init__(self, embedding_size, max_seq_length, num_heads, query_size): + def __init__(self, embedding_size, max_seq_length, ...): self.rope_embeddings = RotaryPositionalEmbedding( embedding_size, max_seq_length ) - self.mha_attention = MultiheadAttention( - num_heads=num_heads, query_size=query_size, key=jax.random.key(0) - ) + self.mha_attention = MultiheadAttention(...) def __call__(self, query, key_, value, index): def process_heads( @@ -155,11 +153,6 @@ def process_heads( return x - embedding_size = 32 - max_seq_length = 8 - seq_length = 4 - num_heads = 2 - query_size = 64 transformer_block = eqx.filter_jit( TransformerBlock(embedding_size, max_seq_length, num_heads, query_size) @@ -170,24 +163,7 @@ def process_heads( v = jnp.ones(shape=(seq_length, query_size)) out = transformer_block(q, k, v, jnp.array(0)) - out = transformer_block(q, k, v, jnp.array(1)) # no re-JITing - ``` - - If you're training a transformer, you likely don't want to use any offset. In - those cases, it can be helpful to use `functools.partial` like so: - ```python - embedding_size = 32 - max_seq_length = 8 - - rot_emb = RotaryPositionalEmbedding( - embedding_size=embedding_size, max_seq_length=max_seq_length - ) - rot_emb = eqx.filter_jit(rot_emb) - rot_emb_no_offset = functools.partial(rot_emb, offset=jnp.array(0)) - - x = jnp.ones(shape=(max_seq_length, embedding_size)) - - assert jnp.allclose(rot_emb(x, offset=jnp.array(0)), rot_emb_no_offset(x)) + out = transformer_block(q, k, v, jnp.array(1)) ``` ??? cite From 68867b102694cee2b5d806b773dc32ef52652d3e Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Fri, 23 Aug 2024 22:18:47 +0200 Subject: [PATCH 5/8] removed max_seq_len --- equinox/nn/_embedding.py | 29 ++++++++--------------------- tests/test_nn.py | 8 ++------ 2 files changed, 10 insertions(+), 27 deletions(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index c011d1e8..c7426040 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -120,10 +120,8 @@ class TransformerBlock(eqx.Module): rope_embeddings: RotaryPositionalEmbedding mha_attention: MultiheadAttention - def __init__(self, embedding_size, max_seq_length, ...): - self.rope_embeddings = RotaryPositionalEmbedding( - embedding_size, max_seq_length - ) + def __init__(self, embedding_size, ...): + self.rope_embeddings = RotaryPositionalEmbedding(embedding_size) self.mha_attention = MultiheadAttention(...) def __call__(self, query, key_, value, index): @@ -155,7 +153,7 @@ def process_heads( transformer_block = eqx.filter_jit( - TransformerBlock(embedding_size, max_seq_length, num_heads, query_size) + TransformerBlock(embedding_size, ...) 
) q = jnp.ones(shape=(seq_length, query_size)) @@ -182,7 +180,6 @@ def process_heads( """ embedding_size: int = field(static=True) - max_seq_length: int = field(static=True) theta: float = field(static=True, default=10_000.0) def __check_init__(self): @@ -190,8 +187,6 @@ def __check_init__(self): raise ValueError("`embedding_size` must not be negative.") if (self.embedding_size % 2) != 0: raise ValueError("`embedding_size` must be even.") - if self.max_seq_length < 0: - raise ValueError("`max_seq_length` must not be negative.") @staticmethod def rotate_half(x: Float[Array, "seq_length embedding_size"]): @@ -234,33 +229,28 @@ def __call__( A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional encoding applied to the input. """ - seq_len, embedding_size = x.shape + if embedding_size != self.embedding_size: raise ValueError( f"x.shape[-1] must match self.embedding_size, " f"but {x.shape[-1]} != {self.embedding_size}" ) - if seq_len > self.max_seq_length: - raise ValueError( - f"seq_len must be less than or equal to self.max_seq_length, " - f"but {seq_len} > {self.max_seq_length}" - ) with jax.ensure_compile_time_eval(): if embedding_size in internal_rope_embedding_cache: freqs_cis = internal_rope_embedding_cache[embedding_size] freqs_cis_seq_len, _ = freqs_cis.shape - if self.max_seq_length > freqs_cis_seq_len: + if seq_len > freqs_cis_seq_len: freqs_cis = self.precompute_freqs_cis( - embedding_size, self.max_seq_length, self.theta + embedding_size, seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis else: - freqs_cis = freqs_cis[: self.max_seq_length] + freqs_cis = freqs_cis[:seq_len] else: freqs_cis = self.precompute_freqs_cis( - embedding_size, self.max_seq_length, self.theta + embedding_size, seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis @@ -282,7 +272,4 @@ def __call__( of oscillation for the sine and cosine waves that encode positional information into the embeddings. The larger the theta value, the slower the oscillations and vice versa. Defaults to 10_000.0 -- `max_seq_length`: The maximum sequence length for which to precompute the - positional encodings. This is used to determine the size of the precomputed - positional encodings. 
""" diff --git a/tests/test_nn.py b/tests/test_nn.py index e4d73c36..6806116a 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1359,9 +1359,7 @@ def test_rope_embeddings_shapes(getkey): query_size = 32 key_size = 32 - rope_embeddings = eqx.nn.RotaryPositionalEmbedding( - embedding_size, max_seq_length=seq_length - ) + rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0)) query_heads = jax.random.normal( @@ -1440,9 +1438,7 @@ def test_rope_embeddings_values(): seq_length, embedding_size ) - rope_embeddings = eqx.nn.RotaryPositionalEmbedding( - embedding_size, max_seq_length=seq_length - ) + rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0)) res = rope_embeddings(x) From 7675b03364318a02d4a0f76895122e3617d37087 Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Thu, 5 Sep 2024 21:36:02 +0200 Subject: [PATCH 6/8] shortened docstring and kept only the relevant stuff --- equinox/nn/_embedding.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index c7426040..e2871483 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -124,7 +124,7 @@ def __init__(self, embedding_size, ...): self.rope_embeddings = RotaryPositionalEmbedding(embedding_size) self.mha_attention = MultiheadAttention(...) - def __call__(self, query, key_, value, index): + def __call__(..., index: Int[Array, ""]): def process_heads( query_heads: Float[Array, "seq_length num_heads qk_size"], key_heads: Float[Array, "seq_length num_heads qk_size"], @@ -143,25 +143,11 @@ def process_heads( return query_heads, key_heads, value_heads x = self.mha_attention( - query=query, - key_=key_, - value=value, + ..., process_heads=functools.partial(process_heads, index=index), ) return x - - - transformer_block = eqx.filter_jit( - TransformerBlock(embedding_size, ...) - ) - - q = jnp.ones(shape=(seq_length, query_size)) - k = jnp.ones(shape=(seq_length, query_size)) - v = jnp.ones(shape=(seq_length, query_size)) - - out = transformer_block(q, k, v, jnp.array(0)) - out = transformer_block(q, k, v, jnp.array(1)) ``` ??? 
cite From 284db389a7f27db75026ad43b7f09d933c7be0fd Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Fri, 6 Sep 2024 22:10:53 +0200 Subject: [PATCH 7/8] fixed subtle cos/sin bug in RoPE, fixed tests, fixed offset overflow bug --- equinox/nn/_embedding.py | 31 ++++++----- tests/test_nn.py | 112 +++++++++++++++++++++++++++------------ 2 files changed, 97 insertions(+), 46 deletions(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index e2871483..e52aa6f4 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -181,14 +181,14 @@ def rotate_half(x: Float[Array, "seq_length embedding_size"]): @staticmethod def precompute_freqs_cis( - embedding_size: int, end: int, theta: float + embedding_size: int, end: Int[ArrayLike, ""], theta: float ) -> Complex[Array, "end half_of_embedding_size"]: freqs = 1.0 / ( theta ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size) ) - t = jnp.arange(float(end)) + t = jnp.arange(end / 1.0) # promote to float freqs_outer = jnp.outer(t, freqs) with jax.numpy_dtype_promotion("standard"): freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j @@ -199,7 +199,7 @@ def precompute_freqs_cis( def __call__( self, x: Float[Array, "seq_length embedding_size"], - offset: Int[Array, ""], + offset: Int[ArrayLike, ""] = 0, *, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "seq_length embedding_size"]: @@ -224,30 +224,37 @@ def __call__( ) with jax.ensure_compile_time_eval(): + min_required_seq_len = offset + seq_len # pyright: ignore TODO: fix typing if embedding_size in internal_rope_embedding_cache: freqs_cis = internal_rope_embedding_cache[embedding_size] freqs_cis_seq_len, _ = freqs_cis.shape - if seq_len > freqs_cis_seq_len: + if min_required_seq_len > freqs_cis_seq_len: # pyright: ignore TODO: fix typing freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, min_required_seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis else: - freqs_cis = freqs_cis[:seq_len] + freqs_cis = freqs_cis[:min_required_seq_len] else: freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, min_required_seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis - freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len) - freqs_real = jnp.tile(freqs_cis.real, (1, 2)) - freqs_imag = jnp.tile(freqs_cis.imag, (1, 2)) + half_size = embedding_size // 2 + freqs_cos = freqs_cis.real[:, :half_size] + freqs_sin = freqs_cis.imag[:, :half_size] + + x_cos, x_sin = x[..., :half_size], x[..., half_size:] - rotate_x = self.rotate_half(x) + x_rope_cos = x_cos * freqs_cos - x_sin * freqs_sin + x_rope_sin = x_cos * freqs_sin + x_sin * freqs_cos + + x_rope = jnp.stack([x_rope_cos, x_rope_sin], axis=-1).reshape( + seq_len, embedding_size + ) - x_rope = (x * freqs_real) + (rotate_x * freqs_imag) return x_rope diff --git a/tests/test_nn.py b/tests/test_nn.py index 6806116a..125f4e79 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1407,42 +1407,86 @@ def test_rope_embeddings_freqs_cis(): def test_rope_embeddings_values(): - # values are generated using - # the script in this gist: - # https://gist.github.com/Artur-Galstyan/d33eda74072fea61545127adb90197b5 - # Those values are generated based on the HuggingFace implementation - # of the Rope embeddings - # (see here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_flax_llama.py#L169) - expected_values = jnp.array( 
- [ - [ - 0.0, - 1.0, - 2.0, - 3.0, - ], - [-2.887617, 4.9297514, 6.6076975, 7.0496492], - [-12.422148, 8.778215, 3.1129112, 11.177788], - [-13.85559, 12.544218, -12.166454, 15.383192], - [3.1641474, 16.226604, -23.874424, 19.664621], - [26.769577, 19.824234, -12.937918, 24.020819], - [30.30889, 23.335985, 18.258457, 28.450514], - [1.3996639, 26.760752, 41.01269, 32.952423], - ] - ) + # These values are generation using the RoPE implementation from lucidrains + # https://github.com/lucidrains/rotary-embedding-torch - seq_length = 8 - embedding_size = 4 + # The gist that generates these values can be found here: + # https://gist.github.com/Artur-Galstyan/8fd9df6d09a5262671dd934d43f91663 - x = jnp.arange(seq_length * embedding_size * 1.0).reshape( - seq_length, embedding_size - ) + expected_values = [ + jnp.array( + [ + [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], + [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010], + [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + ] + ), + jnp.array( + [ + [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010], + [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + ] + ), + jnp.array( + [ + [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + ] + ), + jnp.array( + [ + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + ] + ), + jnp.array( + [ + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + ] + ), + jnp.array( + [ + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], + ] + ), + jnp.array( + [ + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], + [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090], + ] + ), + jnp.array( + [ + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], + [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090], + [-0.2951, -1.3831, -0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099], + ] + ), + ] - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) - rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0)) - res = rope_embeddings(x) + embedding_size = 8 + seq_len = 4 - assert jnp.allclose(res, expected_values, atol=1e-6) - res = rope_embeddings(x) + x = jnp.ones(shape=(seq_len, 
embedding_size)) - assert jnp.allclose(res, expected_values, atol=1e-6) + rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) + rope_embeddings = eqx.filter_jit(rope_embeddings) + for i in range(len(expected_values)): + out = rope_embeddings(x, i) + assert jnp.allclose(out, expected_values[i], atol=1e-4) From 83357926992fd047b4b39017c141a0598dc6e5d8 Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Tue, 10 Sep 2024 21:49:54 +0200 Subject: [PATCH 8/8] revert interleaved to grouped and fix test --- equinox/nn/_embedding.py | 22 ++++------ tests/test_nn.py | 86 ---------------------------------------- 2 files changed, 7 insertions(+), 101 deletions(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index e52aa6f4..bc3b1bf2 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -188,7 +188,7 @@ def precompute_freqs_cis( ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size) ) - t = jnp.arange(end / 1.0) # promote to float + t = jnp.arange(float(end)) # pyright: ignore freqs_outer = jnp.outer(t, freqs) with jax.numpy_dtype_promotion("standard"): freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j @@ -224,11 +224,11 @@ def __call__( ) with jax.ensure_compile_time_eval(): - min_required_seq_len = offset + seq_len # pyright: ignore TODO: fix typing + min_required_seq_len = offset + seq_len # pyright: ignore if embedding_size in internal_rope_embedding_cache: freqs_cis = internal_rope_embedding_cache[embedding_size] freqs_cis_seq_len, _ = freqs_cis.shape - if min_required_seq_len > freqs_cis_seq_len: # pyright: ignore TODO: fix typing + if min_required_seq_len > freqs_cis_seq_len: # pyright: ignore freqs_cis = self.precompute_freqs_cis( embedding_size, min_required_seq_len, self.theta ) @@ -242,19 +242,11 @@ def __call__( internal_rope_embedding_cache[embedding_size] = freqs_cis freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len) - half_size = embedding_size // 2 - freqs_cos = freqs_cis.real[:, :half_size] - freqs_sin = freqs_cis.imag[:, :half_size] - - x_cos, x_sin = x[..., :half_size], x[..., half_size:] - - x_rope_cos = x_cos * freqs_cos - x_sin * freqs_sin - x_rope_sin = x_cos * freqs_sin + x_sin * freqs_cos - - x_rope = jnp.stack([x_rope_cos, x_rope_sin], axis=-1).reshape( - seq_len, embedding_size - ) + freqs_real = jnp.tile(freqs_cis.real, (1, 2)) + freqs_imag = jnp.tile(freqs_cis.imag, (1, 2)) + rotate_x = self.rotate_half(x) + x_rope = (x * freqs_real) + (rotate_x * freqs_imag) return x_rope diff --git a/tests/test_nn.py b/tests/test_nn.py index 125f4e79..dad39b49 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1404,89 +1404,3 @@ def test_rope_embeddings_freqs_cis(): embedding_size, seq_length, theta ) assert jnp.allclose(freqs_cis, expected_freqs_cis, atol=1e-4) - - -def test_rope_embeddings_values(): - # These values are generation using the RoPE implementation from lucidrains - # https://github.com/lucidrains/rotary-embedding-torch - - # The gist that generates these values can be found here: - # https://gist.github.com/Artur-Galstyan/8fd9df6d09a5262671dd934d43f91663 - - expected_values = [ - jnp.array( - [ - [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], - [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010], - [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], - [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], - ] - ), - jnp.array( - [ - [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010], - [-1.3254, 
0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], - [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], - [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], - ] - ), - jnp.array( - [ - [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], - [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], - [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], - [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], - ] - ), - jnp.array( - [ - [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], - [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], - [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], - [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], - ] - ), - jnp.array( - [ - [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], - [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], - [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], - [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], - ] - ), - jnp.array( - [ - [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], - [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], - [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], - [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], - ] - ), - jnp.array( - [ - [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], - [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], - [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], - [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090], - ] - ), - jnp.array( - [ - [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], - [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], - [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090], - [-0.2951, -1.3831, -0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099], - ] - ), - ] - - embedding_size = 8 - seq_len = 4 - - x = jnp.ones(shape=(seq_len, embedding_size)) - - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) - rope_embeddings = eqx.filter_jit(rope_embeddings) - for i in range(len(expected_values)): - out = rope_embeddings(x, i) - assert jnp.allclose(out, expected_values[i], atol=1e-4)
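
Usage sketch (not part of the patches above; a hedged illustration of how the patched class is expected to behave): after PATCH 8/8, `RotaryPositionalEmbedding.__call__` accepts `(x, offset=0)`, where `offset` is the absolute position of the first row of `x`. The snippet below checks the property the series is built around — rotating a single decode-step token at its absolute position should match the corresponding row of the fully rotated sequence. The sizes and variable names (`rope`, `full`, `last`) are illustrative only.

```python
# Minimal sketch, assuming the patched RotaryPositionalEmbedding whose
# __call__ signature is (x, offset=0). Not part of this patch series.
import equinox as eqx
import jax
import jax.numpy as jnp

embedding_size = 8
seq_length = 6

rope = eqx.nn.RotaryPositionalEmbedding(embedding_size)

x = jax.random.normal(jax.random.PRNGKey(0), (seq_length, embedding_size))

# Rotate the full sequence, covering positions 0 .. seq_length - 1.
full = rope(x, offset=jnp.array(0))

# Rotate only the last token, telling RoPE its absolute position.
last = rope(x[-1:], offset=jnp.array(seq_length - 1))

# The single-token result should match the last row of the full result.
assert jnp.allclose(full[-1:], last, atol=1e-6)
```

In the `process_heads` hook shown in the class docstring, the same idea is applied per attention head via `jax.vmap`, with `offset` set to the autoregressive index of the current token.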