Added the offset to RoPE embedding and fixed the pre-commit pyright #799
`.gitignore`:

```diff
@@ -13,3 +13,6 @@ examples/MNIST
 examples/multipart_serialised.eqx
 .python-version
 .DS_Store
+.ruff_cache
+.pytest_cache
+.venv
```
```diff
@@ -126,18 +126,22 @@ def __call__(...):
     def process_heads(
         query_heads: Float[Array, "seq_length num_heads qk_size"],
         key_heads: Float[Array, "seq_length num_heads qk_size"],
-        value_heads: Float[Array, "seq_length num_heads vo_size"]
+        value_heads: Float[Array, "seq_length num_heads vo_size"],
+        index: Int[Array, ""]
     ) -> tuple[
         Float[Array, "seq_length num_heads qk_size"],
         Float[Array, "seq_length num_heads qk_size"],
         Float[Array, "seq_length num_heads vo_size"]
     ]:
-        query_heads = jax.vmap(self.rope_embeddings,
-                               in_axes=1,
-                               out_axes=1)(query_heads)
-        key_heads = jax.vmap(self.rope_embeddings,
-                             in_axes=1,
-                             out_axes=1)(key_heads)
+        # index is the autoregressive index of the current token
+        rope_partial = functools.partial(
+            rope_embeddings,
+            offset=index
+        )
+        query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
+        key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)

         return query_heads, key_heads, value_heads
```
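The core change in the hunk above is the `functools.partial` trick: `vmap` is handed a single-argument function, with the new `offset` bound beforehand so it is shared across all heads. A minimal self-contained sketch of that pattern, with a toy function standing in for the RoPE embedding:

```python
import functools

import jax
import jax.numpy as jnp


def fn(x, offset):
    # Toy stand-in for the RoPE embedding: shift values by `offset`.
    return x + offset


heads = jnp.ones((5, 3, 8))  # (seq_length, num_heads, size)

# Bind the offset first, so the vmapped function takes a single array.
fn_with_offset = functools.partial(fn, offset=jnp.array(2))

# Map over axis 1, the heads axis: one call per head, `offset` shared by all.
out = jax.vmap(fn_with_offset, in_axes=1, out_axes=1)(heads)
print(out.shape)  # (5, 3, 8)
```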
```diff
@@ -161,13 +165,16 @@ def process_heads(
     """

     embedding_size: int = field(static=True)
+    max_seq_length: int = field(static=True)
     theta: float = field(static=True, default=10_000.0)

     def __check_init__(self):
         if self.embedding_size < 0:
             raise ValueError("`embedding_size` must not be negative.")
         if (self.embedding_size % 2) != 0:
             raise ValueError("`embedding_size` must be even.")
+        if self.max_seq_length < 0:
+            raise ValueError("`max_seq_length` must not be negative.")

     @staticmethod
     def rotate_half(x: Float[Array, "seq_length embedding_size"]):
```
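`__check_init__` is Equinox's post-construction validation hook, so the new `max_seq_length` field is validated as soon as the module is built. A hedged sketch of what that buys, assuming the constructor shown in this diff:

```python
import equinox as eqx

# An odd embedding size is rejected at construction time, before any forward
# pass: rotate_half needs the embedding to split cleanly into two halves.
try:
    eqx.nn.RotaryPositionalEmbedding(embedding_size=7, max_seq_length=64)
except ValueError as err:
    print(err)  # `embedding_size` must be even.
```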
```diff
@@ -194,12 +201,14 @@ def precompute_freqs_cis(
     def __call__(
         self,
         x: Float[Array, "seq_length embedding_size"],
+        offset: Int[Array, ""] = jnp.array(0),
         *,
         key: Optional[PRNGKeyArray] = None,
     ) -> Float[Array, "seq_length embedding_size"]:
         """**Arguments:**

         - `x`: A JAX array of shape `(seq_length, embedding_size)`.
+        - `offset`: The offset to apply to the positional encoding.
         - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
             (Keyword only argument.)
```

**Reviewer** (on the `offset` default): Heads-up that we can't use JAX arrays at the global level in Equinox. (Which default values are.) This is an internal Google restriction, and I try not to break my former coworkers!

**Author:** Ah, I didn't know. Is this just Equinox-specific, or should I avoid JAX arrays as defaults in general? In any case, I made this argument mandatory and added some more examples. Unfortunately, something like:

```python
offset: Optional[Int[Array, ""]] = None
offset = offset if offset is not None else jnp.array(0)
...
```

doesn't work under JIT and […]. You don't happen to have a solution for these kinds of issues somewhere in your JAX wizard hat? Understanding JAX's behaviour under JIT is definitely something I need to catch up on!
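The "global level" concern is plain Python semantics: a default value is evaluated once, when the `def` statement runs, which for a module-level class means at import time. Below is a small demonstration of that mechanism, followed by the call style the thread settles on (an explicitly passed `offset`); the `RotaryPositionalEmbedding` usage assumes the API shown in this diff, not necessarily any final merged signature:

```python
import equinox as eqx
import jax.numpy as jnp


def f(offset=jnp.array(0)):  # this array is created once, at import time
    return offset


# The default is a single shared array object, i.e. effectively a global.
assert f() is f()

# The adopted alternative: no array default, callers pass `offset` explicitly.
rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=16, max_seq_length=128)
x = jnp.ones((10, 16))  # ten tokens, already embedded
out_prefill = rope(x, offset=jnp.array(0))  # positions 0..9
out_decode = rope(x, offset=jnp.array(5))   # same tokens at positions 5..14
```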
```diff
@@ -215,37 +224,48 @@ def __call__(
                 f"x.shape[-1] must match self.embedding_size, "
                 f"but {x.shape[-1]} != {self.embedding_size}"
             )
+        if seq_len > self.max_seq_length:
+            raise ValueError(
+                f"seq_len must be less than or equal to self.max_seq_length, "
+                f"but {seq_len} > {self.max_seq_length}"
+            )

         with jax.ensure_compile_time_eval():
             if embedding_size in internal_rope_embedding_cache:
                 freqs_cis = internal_rope_embedding_cache[embedding_size]
                 freqs_cis_seq_len, _ = freqs_cis.shape
-                if seq_len > freqs_cis_seq_len:
+                if self.max_seq_length > freqs_cis_seq_len:
                     freqs_cis = self.precompute_freqs_cis(
-                        embedding_size, seq_len, self.theta
+                        embedding_size, self.max_seq_length, self.theta
                     )
                     internal_rope_embedding_cache[embedding_size] = freqs_cis
                 else:
-                    freqs_cis = freqs_cis[:seq_len]
+                    freqs_cis = freqs_cis[: self.max_seq_length]
             else:
                 freqs_cis = self.precompute_freqs_cis(
-                    embedding_size, seq_len, self.theta
+                    embedding_size, self.max_seq_length, self.theta
                 )
                 internal_rope_embedding_cache[embedding_size] = freqs_cis

+        freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
```

**Reviewer** (on the `dynamic_slice_in_dim` line): This looks wrong to me. If we hit the […]

**Author:** Yes, this was a bug. The fix is to ensure that […]
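The bug discussion hinges on `jax.lax.dynamic_slice` semantics: start indices are clamped so the slice always fits inside the operand, so an undersized `freqs_cis` cache would silently yield the wrong rows rather than raising an error. That is why the cache is now always precomputed up to `max_seq_length`. A standalone illustration, with a toy array in place of the real cache:

```python
import jax
import jax.numpy as jnp

freqs_cis = jnp.arange(8)  # stands in for a cache covering 8 positions
seq_len = 4

# In bounds: rows 2..5, as expected.
print(jax.lax.dynamic_slice_in_dim(freqs_cis, jnp.array(2), seq_len))  # [2 3 4 5]

# Out of bounds: the start index is clamped from 6 to 4 instead of erroring,
# so positions 6..9 would silently get the frequencies of positions 4..7.
print(jax.lax.dynamic_slice_in_dim(freqs_cis, jnp.array(6), seq_len))  # [4 5 6 7]
```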
```diff
         freqs_real = jnp.tile(freqs_cis.real, (1, 2))
         freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))

         rotate_x = self.rotate_half(x)

         x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
         return x_rope


 RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:**

 - `embedding_size`: Size of the token embeddings. Must be non-negative and even.
 - `theta`: The base frequency for the sinusoidal functions. It defines the rate
     of oscillation for the sine and cosine waves that encode positional information
     into the embeddings. The larger the theta value, the slower the oscillations
     and vice versa. Defaults to 10_000.0
+- `max_seq_length`: The maximum sequence length for which to precompute the
+    positional encodings. This is used to determine the size of the precomputed
+    positional encodings.
 """
```
**Reviewer:** I think this will fail, as we don't pass this extra argument inside MHA?

**Author:** Yep, you're right. I've added a more complete example for this.
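For reference, here is a runnable sketch of the kind of "more complete example" being discussed: it wires the index-aware `process_heads` into `eqx.nn.MultiheadAttention`, whose `process_heads` hook expects a three-argument callable, so `index` is bound with `functools.partial` at the call site. The `RotaryPositionalEmbedding(embedding_size, max_seq_length)` constructor and the `offset` argument are assumed to be as shown in this diff, not necessarily the final merged signature:

```python
import functools

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mha = eqx.nn.MultiheadAttention(num_heads=4, query_size=64, key=key)
# qk_size defaults to query_size // num_heads = 16, matching embedding_size.
rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=16, max_seq_length=128)


def process_heads(query_heads, key_heads, value_heads, index):
    # Rotate queries and keys starting from absolute position `index`.
    rope_partial = functools.partial(rope, offset=index)
    query_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(query_heads)
    key_heads = jax.vmap(rope_partial, in_axes=1, out_axes=1)(key_heads)
    return query_heads, key_heads, value_heads


tokens = jnp.ones((10, 64))  # (seq_length, query_size)
index = jnp.array(0)  # autoregressive position of the first token in `tokens`

# MHA's hook takes exactly (q, k, v), so bind `index` before passing it in.
out = mha(
    tokens,
    tokens,
    tokens,
    process_heads=functools.partial(process_heads, index=index),
)
print(out.shape)  # (10, 64)
```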