Added the offset to RoPE embedding and fixed the pre-commit pyright #799

Closed
wants to merge 8 commits into from
3 changes: 3 additions & 0 deletions .gitignore
@@ -13,3 +13,6 @@ examples/MNIST
examples/multipart_serialised.eqx
.python-version
.DS_Store
.ruff_cache
.pytest_cache
.venv
101 changes: 81 additions & 20 deletions equinox/nn/_embedding.py
@@ -118,31 +118,76 @@ class RotaryPositionalEmbedding(Module, strict=True):
```python
class TransformerBlock(eqx.Module):
rope_embeddings: RotaryPositionalEmbedding
mha_attention: MultiheadAttention

def __init__(...):
self.rope_embeddings = RotaryPositionalEmbedding(...)
def __init__(self, embedding_size, max_seq_length, num_heads, query_size):
self.rope_embeddings = RotaryPositionalEmbedding(
embedding_size, max_seq_length
)
self.mha_attention = MultiheadAttention(
num_heads=num_heads, query_size=query_size, key=jax.random.key(0)
)

def __call__(...):
def __call__(self, query, key_, value, index):
def process_heads(
query_heads: Float[Array, "seq_length num_heads qk_size"],
key_heads: Float[Array, "seq_length num_heads qk_size"],
value_heads: Float[Array, "seq_length num_heads vo_size"]
value_heads: Float[Array, "seq_length num_heads vo_size"],
index: Int[Array, ""],
) -> tuple[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"]
Float[Array, "seq_length num_heads vo_size"],
]:
query_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(query_heads)
key_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(key_heads)
# index is the autoregressive index of the current token
rope_p = functools.partial(self.rope_embeddings, offset=index)
query_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(query_heads)
key_heads = jax.vmap(rope_p, in_axes=1, out_axes=1)(key_heads)

return query_heads, key_heads, value_heads

x = self.mha_attention(... process_heads=process_heads)
...
x = self.mha_attention(
query=query,
key_=key_,
value=value,
process_heads=functools.partial(process_heads, index=index),
)

return x

embedding_size = 32
max_seq_length = 8
seq_length = 4
num_heads = 2
query_size = 64

transformer_block = eqx.filter_jit(
TransformerBlock(embedding_size, max_seq_length, num_heads, query_size)
)

q = jnp.ones(shape=(seq_length, query_size))
k = jnp.ones(shape=(seq_length, query_size))
v = jnp.ones(shape=(seq_length, query_size))

out = transformer_block(q, k, v, jnp.array(0))
out = transformer_block(q, k, v, jnp.array(1)) # no re-JITing
```

If you're training a transformer, you likely don't want to use any offset. In
those cases, it can be helpful to use `functools.partial` like so:
```python
embedding_size = 32
max_seq_length = 8

rot_emb = RotaryPositionalEmbedding(
embedding_size=embedding_size, max_seq_length=max_seq_length
)
rot_emb = eqx.filter_jit(rot_emb)
rot_emb_no_offset = functools.partial(rot_emb, offset=jnp.array(0))

x = jnp.ones(shape=(max_seq_length, embedding_size))

assert jnp.allclose(rot_emb(x, offset=jnp.array(0)), rot_emb_no_offset(x))
```
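In a training loop this offset-free version is then typically mapped over the head axis, just as `process_heads` does in the example further above. A quick sketch, reusing `rot_emb_no_offset`, `max_seq_length`, and `embedding_size` from the snippet just shown (the head count here is arbitrary):
```python
num_heads = 2
heads = jnp.ones(shape=(max_seq_length, num_heads, embedding_size))
# Apply the same offset-free RoPE to every head independently.
rotated = jax.vmap(rot_emb_no_offset, in_axes=1, out_axes=1)(heads)
assert rotated.shape == heads.shape
```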

??? cite
@@ -161,13 +206,16 @@ def process_heads(
"""

embedding_size: int = field(static=True)
max_seq_length: int = field(static=True)
theta: float = field(static=True, default=10_000.0)

def __check_init__(self):
if self.embedding_size < 0:
raise ValueError("`embedding_size` must not be negative.")
if (self.embedding_size % 2) != 0:
raise ValueError("`embedding_size` must be even.")
if self.max_seq_length < 0:
raise ValueError("`max_seq_length` must not be negative.")

@staticmethod
def rotate_half(x: Float[Array, "seq_length embedding_size"]):
@@ -194,12 +242,14 @@ def precompute_freqs_cis(
def __call__(
self,
x: Float[Array, "seq_length embedding_size"],
offset: Int[Array, ""],
*,
key: Optional[PRNGKeyArray] = None,
) -> Float[Array, "seq_length embedding_size"]:
"""**Arguments:**

- `x`: A JAX array of shape `(seq_length, embedding_size)`.
- `offset`: The offset to apply to the positional encoding.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)

@@ -208,44 +258,55 @@ def __call__(
A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional
encoding applied to the input.
"""

seq_len, embedding_size = x.shape
if embedding_size != self.embedding_size:
raise ValueError(
f"x.shape[-1] must match self.embedding_size, "
f"but {x.shape[-1]} != {self.embedding_size}"
)
if seq_len > self.max_seq_length:
raise ValueError(
f"seq_len must be less than or equal to self.max_seq_length, "
f"but {seq_len} > {self.max_seq_length}"
)

with jax.ensure_compile_time_eval():
if embedding_size in internal_rope_embedding_cache:
freqs_cis = internal_rope_embedding_cache[embedding_size]
freqs_cis_seq_len, _ = freqs_cis.shape
if seq_len > freqs_cis_seq_len:
if self.max_seq_length > freqs_cis_seq_len:
freqs_cis = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
embedding_size, self.max_seq_length, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis
else:
freqs_cis = freqs_cis[:seq_len]
freqs_cis = freqs_cis[: self.max_seq_length]
else:
freqs_cis = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
embedding_size, self.max_seq_length, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis

freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
Review comment (Owner): This looks wrong to me. If we hit the `else` branch of `if embedding_size in internal_rope_embedding_cache: ... else:` above, then we'll compute an array of length `seq_len`, which will not all be valid when sliced into here -- we'll be indexing off the end.

Reply (Author): Yes, this was a bug. The fix is to ensure that `freqs_cis` is at least as long as `seq_len + offset`.
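
A minimal sketch of that invariant, with hypothetical helper names rather than the code in this diff: the frequency table is rebuilt whenever `offset + seq_len` would run past its end, so the dynamic slice stays in bounds instead of being clamped by `dynamic_slice_in_dim` (which would silently return frequencies for the wrong positions).

```python
import jax
import jax.numpy as jnp


def rope_freqs_cis(embedding_size: int, length: int, theta: float = 10_000.0):
    # Hypothetical stand-in for precompute_freqs_cis: a complex e^(i * m * theta_k)
    # table with one column per rotated pair of dimensions.
    inv_freq = 1.0 / (theta ** (jnp.arange(0.0, embedding_size, 2.0) / embedding_size))
    angles = jnp.outer(jnp.arange(length), inv_freq)
    return jnp.exp(1j * angles)  # shape (length, embedding_size // 2)


def freqs_for_window(embedding_size, seq_len, offset, cache=None):
    # Ensure the cached table covers positions [offset, offset + seq_len).
    needed = offset + seq_len
    if cache is None or cache.shape[0] < needed:
        cache = rope_freqs_cis(embedding_size, needed)
    return jax.lax.dynamic_slice_in_dim(cache, offset, seq_len), cache


window, cache = freqs_for_window(embedding_size=8, seq_len=4, offset=3)
assert window.shape == (4, 4)
```

Here `offset` is a plain Python int so the length check can run eagerly; in `__call__` it is a traced scalar, which is why the diff precomputes the table up to `max_seq_length` at trace time rather than sizing it from the offset.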


freqs_real = jnp.tile(freqs_cis.real, (1, 2))
freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))

rotate_x = self.rotate_half(x)

x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
return x_rope


RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:**

- `embedding_size`: Size of the token embeddings. Must be non-negative and even.
- `theta`: The base frequency for the sinusoidal functions. It defines the rate
of oscillation for the sine and cosine waves that encode positional information
- `theta`: The base frequency for the sinusoidal functions. It defines the rate
of oscillation for the sine and cosine waves that encode positional information
into the embeddings. The larger the theta value, the slower the oscillations
and vice versa. Defaults to 10_000.0
- `max_seq_length`: The maximum sequence length for which to precompute the
positional encodings; this determines the size of the precomputed frequency
table and the longest sequence the module will accept.
"""
19 changes: 15 additions & 4 deletions tests/test_nn.py
@@ -1,3 +1,4 @@
import functools
import warnings
from typing import Union

@@ -238,8 +239,8 @@ def test_mlp_learnt_activation():
key=jrandom.PRNGKey(5678),
)
x = jnp.array([0.5, 0.7])
assert mlp.activation.negative_slope.shape == (2, 8)
assert mlp.final_activation.negative_slope.shape == (5,)
assert mlp.activation.negative_slope.shape == (2, 8) # pyright: ignore
assert mlp.final_activation.negative_slope.shape == (5,) # pyright: ignore

@eqx.filter_jit
@eqx.filter_grad
@@ -1352,13 +1353,17 @@ def test_prelu(getkey):

def test_rope_embeddings_shapes(getkey):
embedding_size = 32
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)

n_heads = 4
seq_length = 8
query_size = 32
key_size = 32

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
embedding_size, max_seq_length=seq_length
)
rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0))

query_heads = jax.random.normal(
key=getkey(), shape=(seq_length, n_heads, query_size)
)
@@ -1435,7 +1440,13 @@ def test_rope_embeddings_values():
seq_length, embedding_size
)

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
embedding_size, max_seq_length=seq_length
)
rope_embeddings = functools.partial(rope_embeddings, offset=jnp.array(0))
res = rope_embeddings(x)

assert jnp.allclose(res, expected_values, atol=1e-6)
res = rope_embeddings(x)

assert jnp.allclose(res, expected_values, atol=1e-6)
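
One extra check that could complement these tests (not part of this diff, and assuming the constructor and `offset` argument exactly as introduced here): applying RoPE to a suffix of the sequence with the matching offset should reproduce the tail of the un-offset output, since each token's rotation depends only on its absolute position.

```python
import equinox as eqx
import jax.numpy as jnp

embedding_size = 32
max_seq_length = 8
k = 3  # hypothetical starting position of the suffix

rope = eqx.nn.RotaryPositionalEmbedding(embedding_size, max_seq_length=max_seq_length)
x = jnp.arange(max_seq_length * embedding_size, dtype=jnp.float32).reshape(
    max_seq_length, embedding_size
)

full = rope(x, offset=jnp.array(0))  # rotate the whole sequence from position 0
tail = rope(x[k:], offset=jnp.array(k))  # rotate only the suffix, starting at position k

assert jnp.allclose(full[k:], tail, atol=1e-6)
```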