Added the offset to RoPE embedding and fixed the pre-commit pyright #799

Closed · wants to merge 8 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -13,3 +13,6 @@ examples/MNIST
examples/multipart_serialised.eqx
.python-version
.DS_Store
.ruff_cache
.pytest_cache
.venv
46 changes: 33 additions & 13 deletions equinox/nn/_embedding.py
@@ -126,18 +126,22 @@ def __call__(...):
def process_heads(
query_heads: Float[Array, "seq_length num_heads qk_size"],
key_heads: Float[Array, "seq_length num_heads qk_size"],
value_heads: Float[Array, "seq_length num_heads vo_size"]
value_heads: Float[Array, "seq_length num_heads vo_size"],
index: Int[Array, ""]
Owner:

I think this will fail, as we don't pass this extra argument inside MHA?

Contributor Author:

Yep, you're right. I've added a more complete example for this.

) -> tuple[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"]
]:
query_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(query_heads)
key_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(key_heads)
# index is the autoregressive index of the current token
rope_partial = functools.partial(self.rope_embeddings, offset=index)
query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)

return query_heads, key_heads, value_heads
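
For reference, a minimal sketch (the helper name is hypothetical) of how the extra `index` can be threaded through without changing the three-argument signature that `MultiheadAttention` calls `process_heads` with: build the function with the index closed over.

    import functools
    import equinox as eqx
    import jax
    from jaxtyping import Array, Int

    def make_process_heads(rope: eqx.nn.RotaryPositionalEmbedding, index: Int[Array, ""]):
        # Return a (q, k, v) -> (q, k, v) function with `index` baked in, matching
        # the signature that MultiheadAttention expects for `process_heads`.
        def process_heads(query_heads, key_heads, value_heads):
            rope_partial = functools.partial(rope, offset=index)
            query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
            key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)
            return query_heads, key_heads, value_heads
        return process_heads

    # Usage: mha(q, k, v, process_heads=make_process_heads(rope, index))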

@@ -161,13 +165,16 @@ def process_heads
"""

embedding_size: int = field(static=True)
max_seq_length: int = field(static=True)
theta: float = field(static=True, default=10_000.0)

def __check_init__(self):
if self.embedding_size < 0:
raise ValueError("`embedding_size` must not be negative.")
if (self.embedding_size % 2) != 0:
raise ValueError("`embedding_size` must be even.")
if self.max_seq_length < 0:
raise ValueError("`max_seq_length` must not be negative.")

@staticmethod
def rotate_half(x: Float[Array, "seq_length embedding_size"]):
@@ -194,12 +201,14 @@ def precompute_freqs_cis(
def __call__(
self,
x: Float[Array, "seq_length embedding_size"],
offset: Int[Array, ""] = jnp.array(0),
Owner:

Heads-up that we can't use JAX arrays at the global level in Equinox. (Which default values are.)

This is an internal Google restriction, and I try not to break my former coworkers!

Contributor Author (Aug 17, 2024):

Ah, I didn't know. Is this just Equinox-specific, or should I avoid JAX arrays as defaults in general?

In any case, I've made this mandatory and added some more examples. Unfortunately, something like:

offset: Optional[Int[Array, ""]] = None

offset = offset if offset is not None else jnp.array(0)
...

doesn't work under JIT, and jax.lax.select doesn't work either because it evaluates both branches.

You don't happen to have a solution for these kinds of issues somewhere in your JAX wizard hat? Understanding JAX's behaviour under JIT is definitely something I need to catch up on!
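
For context on the first point: Python evaluates default values once, at function-definition time, so an array default is materialised when the module is imported. A minimal standalone illustration (not the PR's code):

    import jax.numpy as jnp

    # The default below is built exactly once, at import time, which puts a
    # JAX array at the global level and is what the restriction forbids:
    def f(offset=jnp.array(0)):
        return offset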

*,
key: Optional[PRNGKeyArray] = None,
) -> Float[Array, "seq_length embedding_size"]:
"""**Arguments:**

- `x`: A JAX array of shape `(seq_length, embedding_size)`.
- `offset`: The offset to apply to the positional encoding.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)

@@ -215,37 +224,48 @@ def __call__
f"x.shape[-1] must match self.embedding_size, "
f"but {x.shape[-1]} != {self.embedding_size}"
)
if seq_len > self.max_seq_length:
raise ValueError(
f"seq_len must be less than or equal to self.max_seq_length, "
f"but {seq_len} > {self.max_seq_length}"
)

with jax.ensure_compile_time_eval():
if embedding_size in internal_rope_embedding_cache:
freqs_cis = internal_rope_embedding_cache[embedding_size]
freqs_cis_seq_len, _ = freqs_cis.shape
if seq_len > freqs_cis_seq_len:
if self.max_seq_length > freqs_cis_seq_len:
freqs_cis = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
embedding_size, self.max_seq_length, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis
else:
freqs_cis = freqs_cis[:seq_len]
freqs_cis = freqs_cis[: self.max_seq_length]
else:
freqs_cis = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta
embedding_size, self.max_seq_length, self.theta
)
internal_rope_embedding_cache[embedding_size] = freqs_cis
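
For reference, a sketch of the table being cached in the block above, mirroring the standard RoPE construction (the real implementation is `precompute_freqs_cis`; treat this as illustrative only):

    import jax.numpy as jnp

    def precompute_freqs_cis_sketch(embedding_size, end, theta):
        # One frequency per rotated pair of dimensions.
        freqs = 1.0 / theta ** (jnp.arange(0.0, embedding_size, 2) / embedding_size)
        t = jnp.arange(float(end))        # positions 0 .. end-1
        angles = jnp.outer(t, freqs)      # shape (end, embedding_size // 2)
        return jnp.cos(angles) + 1j * jnp.sin(angles)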

freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
Owner:

This looks wrong to me. If we hit the else branch of if embedding_size in internal_rope_embedding_cache ... else: above, then we'll compute an array of length seq_len, which will not all be valid when sliced into here -- we'll be indexing off the end.

Contributor Author:

Yes, this was a bug. The fix is to ensure that freqs_cis is at least as long as seq_len + offset.
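
A small numeric illustration (hypothetical sizes) of the repaired invariant: the table now always has `max_seq_length` rows, so slicing `seq_len` rows starting at `offset` stays in bounds whenever `offset + seq_len <= max_seq_length`.

    import jax
    import jax.numpy as jnp

    table = jnp.arange(8)                      # stands in for an 8-row freqs_cis
    jax.lax.dynamic_slice_in_dim(table, 3, 4)  # rows 3..6 -> Array([3, 4, 5, 6])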


freqs_real = jnp.tile(freqs_cis.real, (1, 2))
freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))

rotate_x = self.rotate_half(x)

x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
return x_rope


RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:**

- `embedding_size`: Size of the token embeddings. Must be non-negative and even.
- `theta`: The base frequency for the sinusoidal functions. It defines the rate
of oscillation for the sine and cosine waves that encode positional information
into the embeddings. The larger the theta value, the slower the oscillations
and vice versa. Defaults to 10_000.0.
- `max_seq_length`: The maximum sequence length for which to precompute the
positional encodings. This is used to determine the size of the precomputed
positional encodings.
"""
16 changes: 12 additions & 4 deletions tests/test_nn.py
@@ -238,8 +238,8 @@ def test_mlp_learnt_activation():
key=jrandom.PRNGKey(5678),
)
x = jnp.array([0.5, 0.7])
assert mlp.activation.negative_slope.shape == (2, 8)
assert mlp.final_activation.negative_slope.shape == (5,)
assert mlp.activation.negative_slope.shape == (2, 8) # pyright: ignore
assert mlp.final_activation.negative_slope.shape == (5,) # pyright: ignore

@eqx.filter_jit
@eqx.filter_grad
@@ -1352,13 +1352,16 @@ def test_prelu(getkey):

def test_rope_embeddings_shapes(getkey):
embedding_size = 32
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)

n_heads = 4
seq_length = 8
query_size = 32
key_size = 32

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
embedding_size, max_seq_length=seq_length
)

query_heads = jax.random.normal(
key=getkey(), shape=(seq_length, n_heads, query_size)
)
@@ -1435,7 +1438,12 @@ def test_rope_embeddings_values():
seq_length, embedding_size
)

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
embedding_size, max_seq_length=seq_length
)
res = rope_embeddings(x)

assert jnp.allclose(res, expected_values, atol=1e-6)
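
One natural extra check, sketched here but not part of this diff: rotating a suffix with the matching offset should agree with rotating the whole sequence and then slicing.

    full = rope_embeddings(x, offset=jnp.array(0))
    tail = rope_embeddings(x[1:], offset=jnp.array(1))
    assert jnp.allclose(full[1:], tail, atol=1e-6)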