Added the offset to RoPE embedding and fixed the pre-commit pyright #799

Closed
wants to merge 8 commits into from

Artur-Galstyan
Sponsor Contributor

As mentioned in #704, this is the fix for the RoPE embeddings. I also added pyright: ignore to the test_nn.py file; I think someone may have forgotten to run the pre-commit hook. I checked GitHub and it seems we don't run the Pyright check anymore?

Owner

@patrick-kidger patrick-kidger left a comment

Hmm, I don't think this was actually a bug before? It's really just a new feature for RotaryEmbedding, in that it now supports an offset.

My concern here is that in order to use it, we'd need to make a backward-incompatible change to MHA. (To pass the index.)

Since we've decided against implementing KV caching in MHA, then I don't think this really comes up.

@@ -126,18 +126,22 @@ def __call__(...):
 def process_heads(
     query_heads: Float[Array, "seq_length num_heads qk_size"],
     key_heads: Float[Array, "seq_length num_heads qk_size"],
-    value_heads: Float[Array, "seq_length num_heads vo_size"]
+    value_heads: Float[Array, "seq_length num_heads vo_size"],
+    index: Int[Array, ""]
Owner

I think this will fail, as we don't pass this extra argument inside MHA?

Sponsor Contributor Author

Yep, you're right. I've added a more complete example for this.

@@ -194,12 +201,14 @@ def precompute_freqs_cis(
 def __call__(
     self,
     x: Float[Array, "seq_length embedding_size"],
+    offset: Int[Array, ""] = jnp.array(0),
Owner

Heads-up that we can't use JAX arrays at the global level in Equinox. (Which default values are.)

This is an internal Google restriction, and I try not to break my former coworkers!

Sponsor Contributor Author

@Artur-Galstyan Artur-Galstyan Aug 17, 2024

Ah, I didn't know. Is this just Equinox-specific, or should I avoid JAX arrays for defaults in general?

In any case, I made this mandatory then and added some more examples. Unfortunately, something like:

offset: Optional[Int[Array, ""]] = None

offset = offset if offset is not None else jnp.array(0)
...

doesn't work under JIT, and jax.lax.select doesn't work either, because it evaluates both branches.

You don't happen to have a solution for this kind of issue somewhere in your JAX wizard hat? Understanding JAX's behaviour under JIT is definitely something I need to catch up on!
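
For reference, one pattern that is often used for optional array arguments is to resolve the None default in plain Python inside the function body, since `offset is None` is decided while the function is traced rather than at run time. The following is only a minimal sketch under that assumption: `positions_with_offset` is a hypothetical helper, not the RoPE code from this PR, and whether it fits the case described above depends on details not shown here.

```python
from typing import Optional

import jax
import jax.numpy as jnp


@jax.jit
def positions_with_offset(x: jax.Array, offset: Optional[jax.Array] = None) -> jax.Array:
    # `None` carries no array data, so this check runs in Python at trace time.
    if offset is None:
        # The array is created inside the function body, not as a
        # module-level default value.
        offset = jnp.zeros((), dtype=jnp.int32)
    # Position index of each row of x, shifted by the (possibly traced) offset.
    return jnp.arange(x.shape[0]) + offset


x = jnp.ones((4, 8))
without_offset = positions_with_offset(x)
with_offset = positions_with_offset(x, jnp.array(2, dtype=jnp.int32))
```

Calling the function with and without an offset compiles one variant per argument structure; later calls that only change the offset value reuse the existing compilation.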

@Artur-Galstyan
Sponsor Contributor Author

I agree that this is not a bug per se, but it is still a shortcoming of our current version, especially when compared to a PyTorch counterpart, which includes input_pos in its forward function.

> My concern here is that in order to use it, we'd need to make a backward-incompatible change to MHA. (To pass the index.)
> Since we've decided against implementing KV caching in MHA, then I don't think this really comes up.

With the example I provided, the offset can even be made to work without changing the MHA implementation, by using functools.partial. TBH, I didn't think it would work; I expected it to re-JIT on every call when using this:

        x = self.mha_attention(
            query=query,
            key_=key_,
            value=value,
            process_heads=functools.partial(process_heads, index=index), # <--
        )

But I guess the compiler is pretty smart!
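
A likely explanation, sketched under the assumption that the surrounding call runs inside a jitted function: the partial is then created during tracing, so `index` is a tracer rather than a concrete value baked into the compiled code. The `attention` and `process_heads` functions below are hypothetical stand-ins, not the Equinox implementations.

```python
import functools

import jax
import jax.numpy as jnp

traces = {"step": 0}


def process_heads(query_heads, key_heads, value_heads, index):
    # Hypothetical stand-in: shift queries and keys by the traced index.
    return query_heads + index, key_heads + index, value_heads


def attention(query, key_, value, process_heads):
    # Hypothetical stand-in for a layer that takes a three-argument
    # `process_heads` callback.
    query, key_, value = process_heads(query, key_, value)
    return query * key_ + value


@jax.jit
def step(query, key_, value, index):
    traces["step"] += 1  # Python side effect: runs only while JAX traces `step`
    # Built during tracing, so `index` is a tracer here, not a constant.
    bound = functools.partial(process_heads, index=index)
    return attention(query, key_, value, process_heads=bound)


x = jnp.ones((4, 2, 8))
outputs = [step(x, x, x, jnp.array(i, dtype=jnp.int32)) for i in range(3)]
```

In this sketch `step` is traced once even though it is called with three different index values, because the shape and dtype of `index` never change.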

Owner

@patrick-kidger patrick-kidger left a comment

I see.

Okay, so on balance I think I'm happy to have this extra argument, but for simplicity let's not include any extra examples in the documentation. We're not discussing this kind of progressive updating anywhere else, and I'd like to keep this change as easy to review as possible.

@Artur-Galstyan
Sponsor Contributor Author

Artur-Galstyan commented Aug 23, 2024

Ok, so I've removed the max_seq_length argument. It is a breaking change with no backward compatibility and, more importantly, it actually provides no benefit.

Unfold, if you want to know:

Initially, I added it because I wanted to avoid re-JITting. But in the current implementation, we do not force the input array to the RoPE module to have length max_seq_length (and then apply a mask so that only the intended seq_length is effectively used). The input array can have any seq_length, which means the module is re-JITted whenever the seq_length of the input array changes. Thus, the max_seq_length argument is obsolete. In other words, if we were ever to include the max_seq_length argument, we would need to include not just the offset but also the cutoff position, and then apply a mask to the input array. But that would be an even breakier change 😄. (Though, personally, I don't mind breaking changes as long as they are an improvement and properly communicated.)
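
The re-JIT behaviour described above can be sketched as follows. This is an illustration only: `MAX_SEQ_LENGTH`, `pad_to_max`, and the two toy functions are hypothetical names, and the doubling stands in for the real embedding computation.

```python
import jax
import jax.numpy as jnp

MAX_SEQ_LENGTH = 16  # hypothetical fixed upper bound
traces = {"raw": 0, "padded": 0}


@jax.jit
def double_raw(x):
    traces["raw"] += 1  # runs only when JAX traces the function
    return 2.0 * x


@jax.jit
def double_padded(x_padded, length):
    traces["padded"] += 1
    # Rows at or beyond `length` are padding and get masked out.
    mask = jnp.arange(x_padded.shape[0]) < length
    return jnp.where(mask[:, None], 2.0 * x_padded, 0.0)


def pad_to_max(x):
    # Pad the sequence axis with zeros up to MAX_SEQ_LENGTH.
    return jnp.pad(x, ((0, MAX_SEQ_LENGTH - x.shape[0]), (0, 0)))


for seq_length in (4, 4, 5):
    x = jnp.ones((seq_length, 8))
    double_raw(x)
    double_padded(pad_to_max(x), jnp.array(seq_length, dtype=jnp.int32))
```

Here `double_raw` is traced once per distinct seq_length, while `double_padded` is traced once in total, at the cost of always computing on MAX_SEQ_LENGTH rows and carrying a cutoff position.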

Owner

@patrick-kidger patrick-kidger left a comment

Responding to your comment on backward compatibility -- I think this is very important :)

I'm willing to make a compatibility break if absolutely required, but that's generally been fairly unusual / often only on niche features / only in the early life of a library prior to achieving adoption.

It's just not providing a useful tool to those downstream if we do that kind of thing!

@@ -233,19 +254,22 @@ def __call__(
)
internal_rope_embedding_cache[embedding_size] = freqs_cis

freqs_cis = jax.lax.dynamic_slice_in_dim(freqs_cis, offset, seq_len)
Owner

This looks wrong to me. If we hit the else branch of if embedding_size in internal_rope_embedding_cache ... else: above, then we'll compute an array of length seq_len, which will not all be valid when sliced into here -- we'll be indexing off the end.

Sponsor Contributor Author

Yes, this was a bug. The fix is to ensure that freqs_cis is at least as long as seq_len + offset.
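
A small sketch of why the length matters, using integer arrays as stand-ins for the freqs_cis table: jax.lax.dynamic_slice_in_dim does not raise when the slice would run off the end. It clamps the start index so that the slice stays in bounds, so a table that is too short silently yields the wrong positions.

```python
import jax
import jax.numpy as jnp

seq_len, offset = 4, 2

# Table of length seq_len only: the requested start (2) is clamped to 0.
too_short = jnp.arange(seq_len)
clamped = jax.lax.dynamic_slice_in_dim(too_short, offset, seq_len)

# Table of length seq_len + offset: the slice starts where it was asked to.
long_enough = jnp.arange(seq_len + offset)
sliced = jax.lax.dynamic_slice_in_dim(long_enough, offset, seq_len)
```

With these values, `clamped` holds positions 0 to 3 (the offset is ignored), while `sliced` holds positions 2 to 5 as intended.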

@Artur-Galstyan
Sponsor Contributor Author

I actually caught another bug in the implementation, which I hadn't noticed before. Our previous implementation grouped the sin/cos like this:

[cos1, cos2, cos3, cos4, sin1, sin2, sin3, sin4] (all grouped)

but it should have been interleaved like this:

[cos1, sin1, cos2, sin2, cos3, sin3, cos4, sin4]

I used the implementation from lucidrains as a reference for the expected values and updated the test accordingly.

Two problems are left to fix:

  1. Pyright complains at the TODO spots, because ArrayLike might be a complex number: Operator ">" not supported for types "complex" and "int" :(

  2. It keeps re-jitting when the offset is a Python integer, and when the offset is an array we get a ConcretizationTypeError. There is an almost-MVP if you're curious. I'll invest more time into this - I feel like there should be a good solution.

 ) -> Complex[Array, "end half_of_embedding_size"]:
     freqs = 1.0 / (
         theta
         ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
     )

-    t = jnp.arange(float(end))
+    t = jnp.arange(end / 1.0)  # promote to float
Sponsor Contributor Author

(This was also because Pyright kept complaining)

Owner

I think I preferred float(end), rather than implicitly promoting. Even if that means a pyright-ignore.

@patrick-kidger
Owner

I think both interleaved and non-interleaved are acceptable. In fact when we first wrote this, I checked this against the ESM2 implementation here to ensure correctness.

We can't switch this now for backward compatibility.
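
To illustrate how the two layouts relate, here is a minimal sketch assuming the usual rotate-by-pairs formulation of RoPE. The functions `rope_grouped` and `rope_interleaved` are hypothetical and are not the Equinox, ESM2, or lucidrains implementations. In this formulation the two variants differ only in which feature dimensions are paired, so they agree up to a fixed permutation of the feature axis.

```python
import jax
import jax.numpy as jnp


def _cos_sin(seq_len, d, theta):
    # One rotation angle per position and per pair of features.
    freqs = 1.0 / theta ** (jnp.arange(0, d, 2) / d)          # (d/2,)
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]    # (seq_len, d/2)
    return jnp.cos(angles), jnp.sin(angles)


def rope_grouped(x, theta=10000.0):
    # Pairs feature i with feature i + d/2 (first half with second half).
    seq_len, d = x.shape
    cos, sin = _cos_sin(seq_len, d, theta)
    x1, x2 = x[:, : d // 2], x[:, d // 2 :]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def rope_interleaved(x, theta=10000.0):
    # Pairs feature 2i with feature 2i + 1 (adjacent features).
    seq_len, d = x.shape
    cos, sin = _cos_sin(seq_len, d, theta)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(seq_len, d)


x = jax.random.normal(jax.random.PRNGKey(0), (5, 8))
# Permutation that moves even features to the first half, odd to the second.
perm = jnp.concatenate([jnp.arange(0, 8, 2), jnp.arange(1, 8, 2)])
same_up_to_perm = jnp.allclose(rope_interleaved(x)[:, perm], rope_grouped(x[:, perm]))
```

The values produced for a given input still differ between the two layouts, which is why switching the layout changes outputs for existing users even though both are valid rotary embeddings.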

@Artur-Galstyan Artur-Galstyan changed the base branch from main to dev September 10, 2024 20:01
@patrick-kidger patrick-kidger deleted the branch patrick-kidger:dev September 14, 2024 08:01
@patrick-kidger
Owner

Bother, this got autoclosed because the target branch got merged. Feel free to re-open.

@Artur-Galstyan
Sponsor Contributor Author

No worries, I'll fix it once I get back from vacation.
