From 284db389a7f27db75026ad43b7f09d933c7be0fd Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Fri, 6 Sep 2024 22:10:53 +0200 Subject: [PATCH] fixed subtle cos/sin bug in RoPE, fixed tests, fixed offset overflow bug --- equinox/nn/_embedding.py | 31 ++++++----- tests/test_nn.py | 112 +++++++++++++++++++++++++++------------ 2 files changed, 97 insertions(+), 46 deletions(-) diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index e2871483..e52aa6f4 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -181,14 +181,14 @@ def rotate_half(x: Float[Array, "seq_length embedding_size"]): @staticmethod def precompute_freqs_cis( - embedding_size: int, end: int, theta: float + embedding_size: int, end: Int[ArrayLike, ""], theta: float ) -> Complex[Array, "end half_of_embedding_size"]: freqs = 1.0 / ( theta ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size) ) - t = jnp.arange(float(end)) + t = jnp.arange(end / 1.0) # promote to float freqs_outer = jnp.outer(t, freqs) with jax.numpy_dtype_promotion("standard"): freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j @@ -199,7 +199,7 @@ def precompute_freqs_cis( def __call__( self, x: Float[Array, "seq_length embedding_size"], - offset: Int[Array, ""], + offset: Int[ArrayLike, ""] = 0, *, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "seq_length embedding_size"]: @@ -224,30 +224,37 @@ def __call__( ) with jax.ensure_compile_time_eval(): + min_required_seq_len = offset + seq_len # pyright: ignore TODO: fix typing if embedding_size in internal_rope_embedding_cache: freqs_cis = internal_rope_embedding_cache[embedding_size] freqs_cis_seq_len, _ = freqs_cis.shape - if seq_len > freqs_cis_seq_len: + if min_required_seq_len > freqs_cis_seq_len: # pyright: ignore TODO: fix typing freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, min_required_seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis else: - freqs_cis = freqs_cis[:seq_len] + freqs_cis = freqs_cis[:min_required_seq_len] else: freqs_cis = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta + embedding_size, min_required_seq_len, self.theta ) internal_rope_embedding_cache[embedding_size] = freqs_cis - freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len) - freqs_real = jnp.tile(freqs_cis.real, (1, 2)) - freqs_imag = jnp.tile(freqs_cis.imag, (1, 2)) + half_size = embedding_size // 2 + freqs_cos = freqs_cis.real[:, :half_size] + freqs_sin = freqs_cis.imag[:, :half_size] + + x_cos, x_sin = x[..., :half_size], x[..., half_size:] - rotate_x = self.rotate_half(x) + x_rope_cos = x_cos * freqs_cos - x_sin * freqs_sin + x_rope_sin = x_cos * freqs_sin + x_sin * freqs_cos + + x_rope = jnp.stack([x_rope_cos, x_rope_sin], axis=-1).reshape( + seq_len, embedding_size + ) - x_rope = (x * freqs_real) + (rotate_x * freqs_imag) return x_rope diff --git a/tests/test_nn.py b/tests/test_nn.py index 6806116a..125f4e79 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1407,42 +1407,86 @@ def test_rope_embeddings_freqs_cis(): def test_rope_embeddings_values(): - # values are generated using - # the script in this gist: - # https://gist.github.com/Artur-Galstyan/d33eda74072fea61545127adb90197b5 - # Those values are generated based on the HuggingFace implementation - # of the Rope embeddings - # (see here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_flax_llama.py#L169) - expected_values = jnp.array( - [ - [ - 0.0, - 1.0, - 2.0, - 3.0, - ], - [-2.887617, 4.9297514, 6.6076975, 7.0496492], - [-12.422148, 8.778215, 3.1129112, 11.177788], - [-13.85559, 12.544218, -12.166454, 15.383192], - [3.1641474, 16.226604, -23.874424, 19.664621], - [26.769577, 19.824234, -12.937918, 24.020819], - [30.30889, 23.335985, 18.258457, 28.450514], - [1.3996639, 26.760752, 41.01269, 32.952423], - ] - ) + # These values are generation using the RoPE implementation from lucidrains + # https://github.com/lucidrains/rotary-embedding-torch - seq_length = 8 - embedding_size = 4 + # The gist that generates these values can be found here: + # https://gist.github.com/Artur-Galstyan/8fd9df6d09a5262671dd934d43f91663 - x = jnp.arange(seq_length * embedding_size * 1.0).reshape( - seq_length, embedding_size - ) + expected_values = [ + jnp.array( + [ + [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], + [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010], + [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + ] + ), + jnp.array( + [ + [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010], + [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + ] + ), + jnp.array( + [ + [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020], + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + ] + ), + jnp.array( + [ + [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030], + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + ] + ), + jnp.array( + [ + [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040], + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + ] + ), + jnp.array( + [ + [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050], + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], + ] + ), + jnp.array( + [ + [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060], + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], + [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090], + ] + ), + jnp.array( + [ + [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070], + [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080], + [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090], + [-0.2951, -1.3831, -0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099], + ] + ), + ] - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) - rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0)) - res = rope_embeddings(x) + embedding_size = 8 + seq_len = 4 - assert jnp.allclose(res, expected_values, atol=1e-6) - res = rope_embeddings(x) + x = jnp.ones(shape=(seq_len, embedding_size)) - assert jnp.allclose(res, expected_values, atol=1e-6) + rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) + rope_embeddings = eqx.filter_jit(rope_embeddings) + for i in range(len(expected_values)): + out = rope_embeddings(x, i) + assert jnp.allclose(out, expected_values[i], atol=1e-4)