Commit

fixed subtle cos/sin bug in RoPE, fixed tests, fixed offset overflow bug

Artur-Galstyan committed Sep 6, 2024
1 parent 7675b03 commit 284db38
Showing 2 changed files with 97 additions and 46 deletions.
31 changes: 19 additions & 12 deletions equinox/nn/_embedding.py
@@ -181,14 +181,14 @@ def rotate_half(x: Float[Array, "seq_length embedding_size"]):

     @staticmethod
     def precompute_freqs_cis(
-        embedding_size: int, end: int, theta: float
+        embedding_size: int, end: Int[ArrayLike, ""], theta: float
     ) -> Complex[Array, "end half_of_embedding_size"]:
         freqs = 1.0 / (
             theta
             ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
         )

-        t = jnp.arange(float(end))
+        t = jnp.arange(end / 1.0)  # promote to float
         freqs_outer = jnp.outer(t, freqs)
         with jax.numpy_dtype_promotion("standard"):
             freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j
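The `end / 1.0` change sidesteps one plausible failure mode now that `end` is typed as an array-like: `float()` forces a traced JAX value to a Python scalar and raises under tracing, while `/ 1.0` stays an ordinary array op. A minimal sketch of the distinction, not part of the commit (`f` and `g` are hypothetical names):

import jax
import jax.numpy as jnp

@jax.jit
def f(end):
    return end / 1.0  # fine: promotion happens as a traced array op

@jax.jit
def g(end):
    return float(end)  # raises ConcretizationTypeError: `end` is a tracer here

f(jnp.array(4))  # Array(4., dtype=float32)
try:
    g(jnp.array(4))
except jax.errors.ConcretizationTypeError:
    print("float() cannot be applied to traced values")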
@@ -199,7 +199,7 @@ def precompute_freqs_cis(

     def __call__(
         self,
         x: Float[Array, "seq_length embedding_size"],
-        offset: Int[Array, ""],
+        offset: Int[ArrayLike, ""] = 0,
         *,
         key: Optional[PRNGKeyArray] = None,
     ) -> Float[Array, "seq_length embedding_size"]:
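With `offset` now defaulting to `0` and accepting plain Python ints, the module can be called without an offset for a full sequence and with one during incremental decoding. A usage sketch based on the new signature (shapes and variable names are illustrative, not from the commit):

import equinox as eqx
import jax.numpy as jnp

rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=8)

prompt = jnp.ones((4, 8))      # four tokens processed up front
new_token = jnp.ones((1, 8))   # one token generated afterwards

rotated_prompt = rope(prompt)              # offset defaults to 0
rotated_next = rope(new_token, offset=4)   # continue at absolute position 4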
@@ -224,30 +224,37 @@ def __call__(
             )

         with jax.ensure_compile_time_eval():
+            min_required_seq_len = offset + seq_len  # pyright: ignore TODO: fix typing
             if embedding_size in internal_rope_embedding_cache:
                 freqs_cis = internal_rope_embedding_cache[embedding_size]
                 freqs_cis_seq_len, _ = freqs_cis.shape
-                if seq_len > freqs_cis_seq_len:
+                if min_required_seq_len > freqs_cis_seq_len:  # pyright: ignore TODO: fix typing
                     freqs_cis = self.precompute_freqs_cis(
-                        embedding_size, seq_len, self.theta
+                        embedding_size, min_required_seq_len, self.theta
                     )
                     internal_rope_embedding_cache[embedding_size] = freqs_cis
                 else:
-                    freqs_cis = freqs_cis[:seq_len]
+                    freqs_cis = freqs_cis[:min_required_seq_len]
             else:
                 freqs_cis = self.precompute_freqs_cis(
-                    embedding_size, seq_len, self.theta
+                    embedding_size, min_required_seq_len, self.theta
                 )
                 internal_rope_embedding_cache[embedding_size] = freqs_cis

-        freqs_real = jnp.tile(freqs_cis.real, (1, 2))
-        freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))
-
-        rotate_x = self.rotate_half(x)
-
-        x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
+        freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
+
+        half_size = embedding_size // 2
+        freqs_cos = freqs_cis.real[:, :half_size]
+        freqs_sin = freqs_cis.imag[:, :half_size]
+
+        x_cos, x_sin = x[..., :half_size], x[..., half_size:]
+
+        x_rope_cos = x_cos * freqs_cos - x_sin * freqs_sin
+        x_rope_sin = x_cos * freqs_sin + x_sin * freqs_cos
+
+        x_rope = jnp.stack([x_rope_cos, x_rope_sin], axis=-1).reshape(
+            seq_len, embedding_size
+        )
+
         return x_rope
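The two product lines are the real and imaginary parts of a complex multiplication written out: (x_cos + i·x_sin)·(cos θ + i·sin θ). A self-contained sketch checking that equivalence, assuming the same split-halves layout (`rotate_halves` and `rotate_complex` are illustrative helpers, not library code):

import jax.numpy as jnp

def rotate_halves(x, freqs_cis):
    # x: (seq_len, embedding_size); freqs_cis: (seq_len, embedding_size // 2)
    half = x.shape[-1] // 2
    x_cos, x_sin = x[..., :half], x[..., half:]
    cos, sin = freqs_cis.real, freqs_cis.imag
    out_cos = x_cos * cos - x_sin * sin
    out_sin = x_cos * sin + x_sin * cos
    return jnp.stack([out_cos, out_sin], axis=-1).reshape(x.shape)

def rotate_complex(x, freqs_cis):
    # the same rotation, expressed as one complex multiply
    half = x.shape[-1] // 2
    z = (x[..., :half] + 1j * x[..., half:]) * freqs_cis
    return jnp.stack([z.real, z.imag], axis=-1).reshape(x.shape)

seq_len, emb = 4, 8
x = jnp.arange(seq_len * emb, dtype=jnp.float32).reshape(seq_len, emb)
angles = jnp.linspace(0.0, 1.0, emb // 2) * jnp.arange(seq_len)[:, None]
freqs_cis = jnp.cos(angles) + 1j * jnp.sin(angles)
assert jnp.allclose(rotate_halves(x, freqs_cis), rotate_complex(x, freqs_cis), atol=1e-6)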


112 changes: 78 additions & 34 deletions tests/test_nn.py
@@ -1407,42 +1407,86 @@ def test_rope_embeddings_freqs_cis():


 def test_rope_embeddings_values():
-    # values are generated using
-    # the script in this gist:
-    # https://gist.github.com/Artur-Galstyan/d33eda74072fea61545127adb90197b5
-    # Those values are generated based on the HuggingFace implementation
-    # of the Rope embeddings
-    # (see here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_flax_llama.py#L169)
-    expected_values = jnp.array(
-        [
-            [0.0, 1.0, 2.0, 3.0],
-            [-2.887617, 4.9297514, 6.6076975, 7.0496492],
-            [-12.422148, 8.778215, 3.1129112, 11.177788],
-            [-13.85559, 12.544218, -12.166454, 15.383192],
-            [3.1641474, 16.226604, -23.874424, 19.664621],
-            [26.769577, 19.824234, -12.937918, 24.020819],
-            [30.30889, 23.335985, 18.258457, 28.450514],
-            [1.3996639, 26.760752, 41.01269, 32.952423],
-        ]
-    )
-
-    seq_length = 8
-    embedding_size = 4
-
-    x = jnp.arange(seq_length * embedding_size * 1.0).reshape(
-        seq_length, embedding_size
-    )
+    # These values are generated using the RoPE implementation from lucidrains:
+    # https://github.com/lucidrains/rotary-embedding-torch
+    # The gist that generates these values can be found here:
+    # https://gist.github.com/Artur-Galstyan/8fd9df6d09a5262671dd934d43f91663
+
+    expected_values = [
+        jnp.array(
+            [
+                [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
+                [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010],
+                [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020],
+                [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030],
+            ]
+        ),
+        jnp.array(
+            [
+                [-0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099, 0.9990, 1.0010],
+                [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020],
+                [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030],
+                [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040],
+            ]
+        ),
+        jnp.array(
+            [
+                [-1.3254, 0.4932, 0.7814, 1.1787, 0.9798, 1.0198, 0.9980, 1.0020],
+                [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030],
+                [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040],
+                [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050],
+            ]
+        ),
+        jnp.array(
+            [
+                [-1.1311, -0.8489, 0.6598, 1.2509, 0.9696, 1.0295, 0.9970, 1.0030],
+                [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040],
+                [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050],
+                [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060],
+            ]
+        ),
+        jnp.array(
+            [
+                [0.1032, -1.4104, 0.5316, 1.3105, 0.9592, 1.0392, 0.9960, 1.0040],
+                [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050],
+                [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060],
+                [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070],
+            ]
+        ),
+        jnp.array(
+            [
+                [1.2426, -0.6753, 0.3982, 1.3570, 0.9488, 1.0487, 0.9950, 1.0050],
+                [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060],
+                [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070],
+                [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080],
+            ]
+        ),
+        jnp.array(
+            [
+                [1.2396, 0.6808, 0.2607, 1.3900, 0.9382, 1.0582, 0.9940, 1.0060],
+                [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070],
+                [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080],
+                [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090],
+            ]
+        ),
+        jnp.array(
+            [
+                [0.0969, 1.4109, 0.1206, 1.4091, 0.9276, 1.0675, 0.9930, 1.0070],
+                [-1.1349, 0.8439, -0.0206, 1.4141, 0.9169, 1.0767, 0.9920, 1.0080],
+                [-1.3232, -0.4990, -0.1617, 1.4049, 0.9061, 1.0858, 0.9910, 1.0090],
+                [-0.2951, -1.3831, -0.3012, 1.3818, 0.8952, 1.0948, 0.9900, 1.0099],
+            ]
+        ),
+    ]

-    rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
-    rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0))
-    res = rope_embeddings(x)
-
-    assert jnp.allclose(res, expected_values, atol=1e-6)
-
-    res = rope_embeddings(x)
-
-    assert jnp.allclose(res, expected_values, atol=1e-6)
+    embedding_size = 8
+    seq_len = 4
+
+    x = jnp.ones(shape=(seq_len, embedding_size))
+
+    rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
+    rope_embeddings = eqx.filter_jit(rope_embeddings)
+    for i in range(len(expected_values)):
+        out = rope_embeddings(x, i)
+        assert jnp.allclose(out, expected_values[i], atol=1e-4)
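Because `x` is all ones, every input row is identical, so bumping the offset by one simply slides the window of rotated rows — which is why each `expected_values[i + 1]` block above is `expected_values[i]` shifted up by one row. A consistency check one could append to the test (an inferred property, not part of the commit):

for i in range(len(expected_values) - 1):
    a = rope_embeddings(x, i)
    b = rope_embeddings(x, i + 1)
    # rows at the same absolute position must agree
    assert jnp.allclose(a[1:], b[:-1], atol=1e-5)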
