Added the offset to RoPE embedding and fixed the pre-commit pyright #799
Conversation
Hmm, I don't think this was actually a bug before? It's really just a new feature for RotaryEmbedding, in that it now supports an offset.

My concern here is that in order to use it, we'd need to make a backward-incompatible change to MHA (to pass the index). Since we've decided against implementing KV caching in MHA, I don't think this really comes up.
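To make the offset semantics concrete, here is a minimal, hypothetical sketch of an interleaved rotary embedding (a standalone toy, not Equinox's actual implementation), where `offset` shifts the absolute positions that the rotation angles are computed from:

```python
import jax.numpy as jnp

def rope(x, offset=0, theta=10000.0):
    """Toy interleaved RoPE: rotates adjacent pairs (x[2i], x[2i+1])."""
    seq_len, d = x.shape  # d must be even
    freqs = 1.0 / theta ** (jnp.arange(0, d, 2) / d)  # (d // 2,)
    pos = jnp.arange(seq_len) + offset                # absolute positions
    angles = pos[:, None] * freqs[None, :]            # (seq_len, d // 2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, ::2], x[:, 1::2]                    # interleaved pairs
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(seq_len, d)
```

With `offset=0` this reduces to ordinary RoPE; a nonzero offset is what incremental, KV-cache-style decoding would pass, so that rotating a suffix of the sequence with the matching offset agrees with rotating the whole sequence, i.e. `rope(x)[2:]` matches `rope(x[2:], offset=2)`.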
equinox/nn/_embedding.py (outdated):

```diff
@@ -126,18 +126,22 @@ def __call__(...):
 def process_heads(
     query_heads: Float[Array, "seq_length num_heads qk_size"],
     key_heads: Float[Array, "seq_length num_heads qk_size"],
-    value_heads: Float[Array, "seq_length num_heads vo_size"]
+    value_heads: Float[Array, "seq_length num_heads vo_size"],
+    index: Int[Array, ""]
```
I think this will fail, as we don't pass this extra argument inside MHA?
Yep, you're right. I've added a more complete example for this.
equinox/nn/_embedding.py (outdated):

```diff
@@ -194,12 +201,14 @@ def precompute_freqs_cis(
 def __call__(
     self,
     x: Float[Array, "seq_length embedding_size"],
+    offset: Int[Array, ""] = jnp.array(0),
```
Heads-up that we can't use JAX arrays at the global level in Equinox. (Which default values are.)
This is an internal Google restriction, and I try not to break my former coworkers!
Ah, I didn't know. Is this Equinox-specific, or should I avoid JAX arrays as default values in general?

In any case, I've made this argument mandatory and added some more examples. Unfortunately, something like:

```python
offset: Optional[Int[Array, ""]] = None
offset = offset if offset is not None else jnp.array(0)
...
```

doesn't work under JIT, and `jax.lax.select` doesn't work either because it evaluates both branches. You don't happen to have a solution for these kinds of issues somewhere in your JAX wizard hat? Understanding JAX's behaviour under JIT is definitely something I need to catch up on!
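For what it's worth, the `is None` pattern does trace in isolation under `jit`: the check runs in Python at trace time rather than inside the compiled graph, so `jit` simply specializes on whether the argument was supplied. The failure described above may have come from how the argument was threaded through elsewhere. A minimal sketch (toy function, not the Equinox API):

```python
from typing import Optional

import jax
import jax.numpy as jnp

@jax.jit
def add_offset(x: jax.Array, offset: Optional[jax.Array] = None) -> jax.Array:
    # Resolved in Python at trace time: jit caches one trace for the
    # "offset omitted" case and another for the "offset given" case.
    if offset is None:
        offset = jnp.array(0)
    return x + offset
```

(`jax.lax.cond`, unlike `select`, also only executes one branch at run time, though it still traces both and degrades to `select` under `vmap`.)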
I agree that this is not a bug per se, but it's still a shortcoming of our current version, especially when compared to a PyTorch counterpart which includes the offset.
In the example I provided, it's even possible to make this work with the offset without changing the MHA implementation at all, by using:

```python
x = self.mha_attention(
    query=query,
    key_=key_,
    value=value,
    process_heads=functools.partial(process_heads, index=index),  # <--
)
```

But I guess the compiler is pretty smart!
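The `functools.partial` trick above is ordinary argument pre-binding: the caller keeps its existing three-argument signature while the extra argument is baked into the callback. A standalone toy illustration (hypothetical names, not the Equinox API):

```python
import functools

def process_heads(query, key, value, index):
    # Toy stand-in: shift the "query" values by the pre-bound index.
    return [q + index for q in query], key, value

# The caller only ever supplies (query, key, value); index is pre-bound.
bound = functools.partial(process_heads, index=3)
q, k, v = bound([1, 2], [10], [20])
```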
I see.
Okay, so on balance I think I'm happy to have this extra argument, but for simplicity let's not include any extra examples in the documentation. We're not discussing this kind of progressive updating anywhere else, and I'd like to keep this change as easy to review as possible.
Ok, so I've removed the `max_seq_length` argument. Unfold if you want to know why:

Initially, I added it because I wanted to avoid re-JITting. But in the current implementation, we are not enforcing the input array to the RoPE module to be of length `max_seq_length` (and then applying a mask to effectively only "use" the intended `seq_length`). The input array can have any `seq_length`, which means the module will always be re-JITted if the `seq_length` of the input array changes. Thus, the `max_seq_length` argument is obsolete.

In other words, if we were to ever include the `max_seq_length` argument, we would also need to include not just the offset but also the cutoff position, and then apply a mask to the input array. But that would be an even breakier change 😄. (Though, personally, I don't mind breaking changes as long as they are an improvement and properly communicated.)
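The re-JITting behaviour mentioned above is easy to observe: `jax.jit` retraces whenever an argument's shape changes, so a Python side effect in the function body fires once per distinct input shape. A minimal sketch:

```python
import jax
import jax.numpy as jnp

traced_shapes = []

@jax.jit
def double(x):
    traced_shapes.append(x.shape)  # side effect: runs only while tracing
    return x * 2

double(jnp.ones(4))
double(jnp.ones(4))  # same shape: cache hit, no retrace
double(jnp.ones(8))  # new shape: retraces and recompiles
```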
Responding to your comment on backward compatibility -- I think this is very important :)
I'm willing to make a compatibility break if absolutely required, but that's generally been fairly unusual / often only on niche features / only in the early life of a library prior to achieving adoption.
It's just not providing a useful tool to those downstream if we do that kind of thing!
```diff
@@ -233,19 +254,22 @@ def __call__(
     )
     internal_rope_embedding_cache[embedding_size] = freqs_cis

+    freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
```
This looks wrong to me. If we hit the `else` branch of `if embedding_size in internal_rope_embedding_cache ... else:` above, then we'll compute an array of length `seq_len`, which will not all be valid when sliced into here -- we'll be indexing off the end.
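Worth noting how that would surface: `jax.lax.dynamic_slice_in_dim` does not raise on an out-of-bounds start index; it clamps the start so the slice fits inside the array. So the bug would produce silently wrong frequency values rather than an error:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6)
# Requesting 4 elements starting at index 4 can't fit in a length-6
# array, so JAX clamps the start to 2 and returns x[2:6] -- no error.
out = jax.lax.dynamic_slice_in_dim(x, 4, 4)
```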
Yes, this was a bug. The fix is to ensure that `freqs_cis` is at least as long as `seq_len + offset`.

I actually caught another bug in the implementation, which I hadn't noticed before. Our previous implementation grouped the sin/cos terms (rotating the pairs `(x[i], x[i + d/2])`), but it should have been interleaved (rotating the adjacent pairs `(x[2i], x[2i+1])`). I used the implementation from lucidrains as a reference for the expected values and updated the test accordingly. Two problems are left to fix:
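To illustrate the two conventions being contrasted, here is a toy with a single shared rotation angle (in real RoPE each pair gets its own frequency): "grouped" rotates the pairs `(x[i], x[i + d/2])`, while "interleaved" rotates adjacent pairs `(x[2i], x[2i+1])`, and the two generally produce different vectors:

```python
import jax.numpy as jnp

def rotate_pair(a, b, angle):
    """Apply a 2D rotation to the pair (a, b)."""
    return (a * jnp.cos(angle) - b * jnp.sin(angle),
            a * jnp.sin(angle) + b * jnp.cos(angle))

x = jnp.arange(4.0)  # one token, embedding size d = 4
angle = 0.5          # single shared angle, for illustration only

# Grouped convention: rotate (x[i], x[i + d // 2]).
g0, g2 = rotate_pair(x[0], x[2], angle)
g1, g3 = rotate_pair(x[1], x[3], angle)
grouped = jnp.array([g0, g1, g2, g3])

# Interleaved convention: rotate (x[2i], x[2i + 1]).
i0, i1 = rotate_pair(x[0], x[1], angle)
i2, i3 = rotate_pair(x[2], x[3], angle)
interleaved = jnp.array([i0, i1, i2, i3])
```

Both are valid rotary embeddings; they just pair up coordinates differently, which is why values checked against a reference implementation only match if the conventions agree.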
equinox/nn/_embedding.py
Outdated
) -> Complex[Array, "end half_of_embedding_size"]: | ||
freqs = 1.0 / ( | ||
theta | ||
** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size) | ||
) | ||
|
||
t = jnp.arange(float(end)) | ||
t = jnp.arange(end / 1.0) # promote to float |
(This was also because Pyright kept complaining)
I think I preferred `float(end)`, rather than implicitly promoting. Even if that means a pyright-ignore.

I think both interleaved and non-interleaved are acceptable. In fact, when we first wrote this, I checked it against the ESM2 implementation here to ensure correctness. We can't switch this now, for backward compatibility.
Bother, this got autoclosed because the target branch got merged. Feel free to re-open.
No worries, I'll fix it once I get back from vacation.
As mentioned in #704, this is the fix for the RoPE embeddings. I also added `pyright: ignore` to the `test_nn.py` file. I think perhaps someone forgot to run the pre-commit hook. I checked GH and it seems we don't run the Pyright check anymore?