diff --git a/.gitignore b/.gitignore index 57e24bc2..eb0d0e08 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ examples/MNIST examples/multipart_serialised.eqx .python-version .DS_Store +.ruff_cache +.pytest_cache +.venv diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index 5fb358b0..bc3b1bf2 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -118,31 +118,36 @@ class RotaryPositionalEmbedding(Module, strict=True): ```python class TransformerBlock(eqx.Module): rope_embeddings: RotaryPositionalEmbedding + mha_attention: MultiheadAttention - def __init__(...): - self.rope_embeddings = RotaryPositionalEmbedding(...) + def __init__(self, embedding_size, ...): + self.rope_embeddings = RotaryPositionalEmbedding(embedding_size) + self.mha_attention = MultiheadAttention(...) - def __call__(...): + def __call__(..., index: Int[Array, ""]): def process_heads( query_heads: Float[Array, "seq_length num_heads qk_size"], key_heads: Float[Array, "seq_length num_heads qk_size"], - value_heads: Float[Array, "seq_length num_heads vo_size"] + value_heads: Float[Array, "seq_length num_heads vo_size"], + index: Int[Array, ""], ) -> tuple[ Float[Array, "seq_length num_heads qk_size"], Float[Array, "seq_length num_heads qk_size"], - Float[Array, "seq_length num_heads vo_size"] + Float[Array, "seq_length num_heads vo_size"], ]: - query_heads = jax.vmap(self.rope_embeddings, - in_axes=1, - out_axes=1)(query_heads) - key_heads = jax.vmap(self.rope_embeddings, - in_axes=1, - out_axes=1)(key_heads) + # index is the autoregressive index of the current token + rope_p = functools.partial(self.rope_embeddings, offset=index) + query_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(query_heads) + key_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(key_heads) return query_heads, key_heads, value_heads - x = self.mha_attention(... process_heads=process_heads) - ... + x = self.mha_attention( + ..., + process_heads=functools.partial(process_heads, index=index), + ) + + return x ``` ??? cite @@ -176,14 +181,14 @@ def rotate_half(x: Float[Array, "seq_length embedding_size"]): @staticmethod def precompute_freqs_cis( - embedding_size: int, end: int, theta: float + embedding_size: int, end: Int[ArrayLike, ""], theta: float ) -> Complex[Array, "end half_of_embedding_size"]: freqs = 1.0 / ( theta ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size) ) - t = jnp.arange(float(end)) + t = jnp.arange(float(end)) # pyright: ignore freqs_outer = jnp.outer(t, freqs) with jax.numpy_dtype_promotion("standard"): freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j @@ -194,12 +199,14 @@ def precompute_freqs_cis( def __call__( self, x: Float[Array, "seq_length embedding_size"], + offset: Int[ArrayLike, ""] = 0, *, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "seq_length embedding_size"]: """**Arguments:** - `x`: A JAX array of shape `(seq_length, embedding_size)`. + - `offset`: The offset to apply to the positional encoding. - `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) @@ -208,8 +215,8 @@ def __call__( A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional encoding applied to the input. """ - seq_len, embedding_size = x.shape + if embedding_size != self.embedding_size: raise ValueError( f"x.shape[-1] must match self.embedding_size, " @@ -217,21 +224,23 @@ def __call__( ) with jax.ensure_compile_time_eval(): + min_required_seq_len = offset + seq_len # pyright: ignore if embedding_size in internal_rope_embedding_cache: freqs_cis = internal_rope_embedding_cache[embedding_size] freqs_cis_seq_len, _ = freqs_cis.shape - if seq_len > freqs_cis_seq_len: + if min_required_seq_len > freqs_cis_seq_len: # pyright: ignore freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, min_required_seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis else: - freqs_cis = freqs_cis[:seq_len] + freqs_cis = freqs_cis[:min_required_seq_len] else: freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, min_required_seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis + freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len) freqs_real = jnp.tile(freqs_cis.real, (1, 2)) freqs_imag = jnp.tile(freqs_cis.imag, (1, 2)) @@ -244,8 +253,8 @@ def __call__( RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:** - `embedding_size`: Size of the token embeddings. Must be non-negative and even. -- `theta`: The base frequency for the sinusoidal functions. It defines the rate - of oscillation for the sine and cosine waves that encode positional information +- `theta`: The base frequency for the sinusoidal functions. It defines the rate + of oscillation for the sine and cosine waves that encode positional information into the embeddings. The larger the theta value, the slower the oscillations and vice versa. Defaults to 10_000.0 """ diff --git a/tests/test_nn.py b/tests/test_nn.py index 157b799c..dad39b49 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,3 +1,4 @@ +import functools import warnings from typing import Union @@ -238,8 +239,8 @@ def test_mlp_learnt_activation(): key=jrandom.PRNGKey(5678), ) x = jnp.array([0.5, 0.7]) - assert mlp.activation.negative_slope.shape == (2, 8) - assert mlp.final_activation.negative_slope.shape == (5,) + assert mlp.activation.negative_slope.shape == (2, 8) # pyright: ignore + assert mlp.final_activation.negative_slope.shape == (5,) # pyright: ignore @eqx.filter_jit @eqx.filter_grad @@ -1352,13 +1353,15 @@ def test_prelu(getkey): def test_rope_embeddings_shapes(getkey): embedding_size = 32 - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) n_heads = 4 seq_length = 8 query_size = 32 key_size = 32 + rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) + rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0)) + query_heads = jax.random.normal( key=getkey(), shape=(seq_length, n_heads, query_size) ) @@ -1401,41 +1404,3 @@ def test_rope_embeddings_freqs_cis(): embedding_size, seq_length, theta ) assert jnp.allclose(freqs_cis, expected_freqs_cis, atol=1e-4) - - -def test_rope_embeddings_values(): - # values are generated using - # the script in this gist: - # https://gist.github.com/Artur-Galstyan/d33eda74072fea61545127adb90197b5 - # Those values are generated based on the HuggingFace implementation - # of the Rope embeddings - # (see here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_flax_llama.py#L169) - expected_values = jnp.array( - [ - [ - 0.0, - 1.0, - 2.0, - 3.0, - ], - [-2.887617, 4.9297514, 6.6076975, 7.0496492], - [-12.422148, 8.778215, 3.1129112, 11.177788], - [-13.85559, 12.544218, -12.166454, 15.383192], - [3.1641474, 16.226604, -23.874424, 19.664621], - [26.769577, 19.824234, -12.937918, 24.020819], - [30.30889, 23.335985, 18.258457, 28.450514], - [1.3996639, 26.760752, 41.01269, 32.952423], - ] - ) - - seq_length = 8 - embedding_size = 4 - - x = jnp.arange(seq_length * embedding_size * 1.0).reshape( - seq_length, embedding_size - ) - - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) - res = rope_embeddings(x) - - assert jnp.allclose(res, expected_values, atol=1e-6)