Added the offset to RoPE embedding and fixed the pre-commit pyright #799

Closed
wants to merge 8 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -13,3 +13,6 @@ examples/MNIST
examples/multipart_serialised.eqx
.python-version
.DS_Store
.ruff_cache
.pytest_cache
.venv
56 changes: 40 additions & 16 deletions equinox/nn/_embedding.py
@@ -118,31 +118,50 @@ class RotaryPositionalEmbedding(Module, strict=True):
```python
class TransformerBlock(eqx.Module):
rope_embeddings: RotaryPositionalEmbedding
mha_attention: MultiheadAttention

def __init__(...):
self.rope_embeddings = RotaryPositionalEmbedding(...)
def __init__(self, embedding_size, ...):
self.rope_embeddings = RotaryPositionalEmbedding(embedding_size)
self.mha_attention = MultiheadAttention(...)

def __call__(...):
def __call__(self, query, key_, value, index):
def process_heads(
query_heads: Float[Array, "seq_length num_heads qk_size"],
key_heads: Float[Array, "seq_length num_heads qk_size"],
value_heads: Float[Array, "seq_length num_heads vo_size"]
value_heads: Float[Array, "seq_length num_heads vo_size"],
index: Int[Array, ""],
) -> tuple[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"]
Float[Array, "seq_length num_heads vo_size"],
]:
query_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(query_heads)
key_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(key_heads)
# index is the autoregressive index of the current token
rope_p = functools.partial(self.rope_embeddings, offset=index)
query_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(query_heads)
key_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(key_heads)

return query_heads, key_heads, value_heads

x = self.mha_attention(... process_heads=process_heads)
...
x = self.mha_attention(
query=query,
key_=key_,
value=value,
process_heads=functools.partial(process_heads, index=index),
)

return x


transformer_block = eqx.filter_jit(
TransformerBlock(embedding_size, ...)
)

q = jnp.ones(shape=(seq_length, query_size))
k = jnp.ones(shape=(seq_length, query_size))
v = jnp.ones(shape=(seq_length, query_size))

out = transformer_block(q, k, v, jnp.array(0))
out = transformer_block(q, k, v, jnp.array(1))
```
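
For a more direct look at the new argument, here is a minimal usage sketch (not part of the diff; the sizes and values are made up) of calling `RotaryPositionalEmbedding` with the `offset` proposed in this PR:

```python
import jax.numpy as jnp
import equinox as eqx

embedding_size = 8
rope = eqx.nn.RotaryPositionalEmbedding(embedding_size)

# Prefill: rotate a 4-token prompt starting at absolute position 0.
prompt = jnp.ones((4, embedding_size))
rotated_prompt = rope(prompt, offset=jnp.array(0))

# Decode step: a single new token sits at absolute position 4, so it is
# rotated with offset=4 rather than starting again from position 0.
new_token = jnp.ones((1, embedding_size))
rotated_token = rope(new_token, offset=jnp.array(4))
```

This mirrors the `process_heads` example above, where `offset=index` tracks the autoregressive position during decoding.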

??? cite
@@ -194,12 +213,14 @@ def precompute_freqs_cis(
def __call__(
self,
x: Float[Array, "seq_length embedding_size"],
offset: Int[Array, ""],
*,
key: Optional[PRNGKeyArray] = None,
) -> Float[Array, "seq_length embedding_size"]:
"""**Arguments:**

- `x`: A JAX array of shape `(seq_length, embedding_size)`.
- `offset`: The offset to apply to the positional encoding.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)

@@ -208,8 +229,8 @@ def __call__(
A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional
encoding applied to the input.
"""

seq_len, embedding_size = x.shape

if embedding_size != self.embedding_size:
raise ValueError(
f"x.shape[-1] must match self.embedding_size, "
@@ -233,19 +254,22 @@ def __call__(
)
internal_rope_embedding_cache[embedding_size] = freqs_cis

freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
Owner commented:

This looks wrong to me. If we hit the `else` branch of the `if embedding_size in internal_rope_embedding_cache ... else:` above, then we'll compute an array of length `seq_len`, which will not all be valid when sliced into here -- we'll be indexing off the end.

Contributor Author (Artur-Galstyan) replied:

Yes, this was a bug. The fix is to ensure that `freqs_cis` is at least as long as `seq_len + offset`.
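
For illustration, a minimal sketch of one way to guarantee that, assuming a static upper bound `max_seq_len >= seq_len + offset` is known ahead of time (this bound, and calling `precompute_freqs_cis` as a static method on the class, are assumptions of the sketch, not something in the diff). Note that `jax.lax.dynamic_slice_in_dim` clamps an out-of-range start index rather than raising, so the unfixed version silently rotates with the wrong positions:

```python
import jax
import equinox as eqx

def sliced_rope_freqs(embedding_size, max_seq_len, theta, offset, seq_len):
    # Hypothetical helper, not this PR's implementation: build the frequency
    # cache long enough for every position we may slice at, then take the
    # window covering positions offset .. offset + seq_len - 1.
    freqs_cis = eqx.nn.RotaryPositionalEmbedding.precompute_freqs_cis(
        embedding_size, max_seq_len, theta
    )
    return jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
```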


freqs_real = jnp.tile(freqs_cis.real, (1, 2))
freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))

rotate_x = self.rotate_half(x)

x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
return x_rope
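
As a side note on the two tiled arrays above, here is a self-contained sketch of the rotation identity being applied, written out from the standard RoPE formulation rather than copied from this file:

```python
import jax.numpy as jnp

def rotate_half(x):
    # Pair dimension i with dimension i + d/2: (x1, x2) -> (-x2, x1).
    d = x.shape[-1] // 2
    return jnp.concatenate([-x[..., d:], x[..., :d]], axis=-1)

def apply_rotary(x, cos, sin):
    # Real form of (x1 + i x2) * (cos + i sin): the real part is
    # x1*cos - x2*sin and the imaginary part is x2*cos + x1*sin, i.e.
    # x * cos + rotate_half(x) * sin once cos and sin are tiled twice.
    return x * cos + rotate_half(x) * sin
```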


RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:**

- `embedding_size`: Size of the token embeddings. Must be non-negative and even.
- `theta`: The base frequency for the sinusoidal functions. It defines the rate
of oscillation for the sine and cosine waves that encode positional information
into the embeddings. The larger the theta value, the slower the oscillations
and vice versa. Defaults to 10_000.0
"""
13 changes: 10 additions & 3 deletions tests/test_nn.py
@@ -1,3 +1,4 @@
import functools
import warnings
from typing import Union

@@ -238,8 +239,8 @@ def test_mlp_learnt_activation():
key=jrandom.PRNGKey(5678),
)
x = jnp.array([0.5, 0.7])
assert mlp.activation.negative_slope.shape == (2, 8)
assert mlp.final_activation.negative_slope.shape == (5,)
assert mlp.activation.negative_slope.shape == (2, 8) # pyright: ignore
assert mlp.final_activation.negative_slope.shape == (5,) # pyright: ignore

@eqx.filter_jit
@eqx.filter_grad
@@ -1352,13 +1353,15 @@ def test_prelu(getkey):

def test_rope_embeddings_shapes(getkey):
embedding_size = 32
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)

n_heads = 4
seq_length = 8
query_size = 32
key_size = 32

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0))

query_heads = jax.random.normal(
key=getkey(), shape=(seq_length, n_heads, query_size)
)
@@ -1436,6 +1439,10 @@ def test_rope_embeddings_values():
)

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0))
res = rope_embeddings(x)

assert jnp.allclose(res, expected_values, atol=1e-6)
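
Beyond the shape and value checks above, one property test that could be added for the new argument (a sketch only, not part of this PR; it follows the style of the surrounding tests) is that a non-zero `offset` matches the tail of applying RoPE from position zero:

```python
def test_rope_embeddings_offset_matches_positions(getkey):
    # rope(x[k:], offset=k) should equal rope(x, offset=0)[k:], since each row
    # is rotated only according to its absolute position.
    embedding_size = 16
    seq_length = 10
    k = 4
    rope = eqx.nn.RotaryPositionalEmbedding(embedding_size)
    x = jax.random.normal(key=getkey(), shape=(seq_length, embedding_size))
    full = rope(x, offset=jnp.array(0))
    tail = rope(x[k:], offset=jnp.array(k))
    assert jnp.allclose(full[k:], tail, atol=1e-6)
```

With the cache-length bug flagged in the review above, this check would fail, so it would also guard the proposed fix.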