haliax.dot: allow defining output order a la einsum #2

Open
dlwh opened this issue Apr 10, 2023 · 8 comments

Comments

@dlwh
Member

dlwh commented Apr 10, 2023

Sidi asked for something like this, and I think it's a good idea. (He asked for proper einsum support, which is also a good idea)

cf stanford-crfm/levanter#107

I'd like to support something like:

haliax.dot(Embed, key, query, out_axes=(..., Head, KeySeqLen, SeqLen))

which would force the order of the output axes.
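To make the intended `out_axes` behavior concrete, here is a minimal pure-Python sketch of how an output spec containing `...` could be resolved against the axis names a contraction produces. The helper `resolve_out_axes` and the `Batch` axis are hypothetical and used only for illustration; this is not haliax's implementation.

```python
def resolve_out_axes(result_axes, out_axes):
    """Reorder result_axes to satisfy out_axes (hypothetical helper, not haliax API).

    `...` in out_axes stands for every result axis that is not named
    explicitly, kept in its original relative order.
    """
    out_axes = tuple(out_axes)
    named = [ax for ax in out_axes if ax is not Ellipsis]
    if len(set(named)) != len(named):
        raise ValueError(f"duplicate axis in out_axes: {named}")
    missing = [ax for ax in named if ax not in result_axes]
    if missing:
        raise ValueError(f"axes not present in the result: {missing}")

    # Axes the caller did not name; these are what `...` absorbs.
    rest = tuple(ax for ax in result_axes if ax not in named)
    n_ellipsis = sum(1 for ax in out_axes if ax is Ellipsis)
    if n_ellipsis > 1:
        raise ValueError("at most one `...` is allowed in out_axes")
    if n_ellipsis == 0:
        if rest:
            raise ValueError(f"axes missing from out_axes: {list(rest)}")
        return tuple(named)

    i = next(idx for idx, ax in enumerate(out_axes) if ax is Ellipsis)
    return out_axes[:i] + rest + out_axes[i + 1:]


# Assumed "natural" result order after contracting Embed (illustrative only).
natural = ("Batch", "SeqLen", "Head", "KeySeqLen")
forced = resolve_out_axes(natural, (..., "Head", "KeySeqLen", "SeqLen"))
print(forced)
```

The point of the sketch is that the caller names only the trailing axes it cares about, and any leading batch-like axes pass through `...` untouched.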

@dlwh dlwh transferred this issue from stanford-crfm/levanter Jun 26, 2023
@reachtarunhere

I really like Haliax so far. I do think that we should have proper einsum support. To me, the string often doubles as documentation.

@dlwh
Member Author

dlwh commented Jan 12, 2024

yeah I was thinking that too. I added out_axes in the secret-ish dev branch. WDYT the syntax should be like?

@dlwh
Member Author

dlwh commented Jan 12, 2024

(and thanks!)

@reachtarunhere

reachtarunhere commented Jan 21, 2024

> yeah I was thinking that too. I added out_axes in the secret-ish dev branch. WDYT the syntax should be like?

Ideally I would prefer a similar syntax to einops as shown here.

I very much prefer the explicit einsum("i j, j k -> i k", mat1, mat2) vs the dot syntax currently in the lib.

We can enforce that the real axis names we have in Haliax are used on the left-hand side, instead of arbitrary letters like i, j, k.

This does have a disadvantage compared with the dot method in terms of privacy of axes: for something like bmm (batched matrix multiply), the einsum code above will break because the batch axis is not named, while the dot code will work just fine. Could this be countered by batching over the axes not mentioned?
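The trade-off can be sketched in plain Python by tracking axis names only. Both helpers below are hypothetical and exist purely to illustrate the point: a strict einsum-style check rejects operands that carry an axis the string does not mention, while name-based contraction lets unmentioned axes (such as a batch axis) pass through.

```python
def strict_einsum_check(input_specs, *operand_axes):
    """Strict reading of an einsum string: each operand must have exactly
    the axes listed for it. Hypothetical helper for illustration."""
    for spec, axes in zip(input_specs, operand_axes):
        if tuple(spec) != tuple(axes):
            raise ValueError(f"operand axes {axes} do not match spec {spec}")


def dot_result_axes(contract, *operand_axes):
    """Name-based contraction: every axis that is not contracted survives,
    in first-seen order, so unmentioned axes are batched over."""
    out = []
    for axes in operand_axes:
        for ax in axes:
            if ax not in contract and ax not in out:
                out.append(ax)
    return tuple(out)


# Batched operands: the caller of a matmul helper may not know about "batch".
mat1_axes = ("batch", "i", "j")
mat2_axes = ("batch", "j", "k")

# The dot-style call only names the contracted axis and keeps working.
print(dot_result_axes({"j"}, mat1_axes, mat2_axes))

# The strict "i j, j k -> i k" reading fails on the extra batch axis.
try:
    strict_einsum_check([("i", "j"), ("j", "k")], mat1_axes, mat2_axes)
except ValueError as err:
    print("strict einsum rejected the operands:", err)
```

Batching over the axes not mentioned, as suggested above, amounts to making the einsum path behave like `dot_result_axes` for any axis absent from the string.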

@dlwh
Member Author

dlwh commented Jan 21, 2024

Yeah, I am coming around to this point of view. WDYT about the syntax for new-einops-style rearrange, particularly the dev version?

We could support this syntax for dot with something like:

  • support normal einops syntax, including short name-capture: hax.dot("... c h w, h w d -> ... c d", a, b)
  • hax.dot("{h, w} -> ", a, b) means "contract h and w", analogous to hax.dot(a, b, axis=("h", "w"))
  • hax.dot("{h, w} -> ... channel embed", a, b) means "contract h and w and ensure that the result ends with [channel, embed]" (by transposing/einsum)
  • hax.dot(" -> batch channel embed", a, b) could mean "contract all but the named dims". Not entirely sure how I feel about that one, but used situationally it's probably ok
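The four forms in the list above can be told apart with a small amount of string parsing. The following is a rough sketch under stated assumptions: `DotSpec` and `parse_dot_spec` are hypothetical names, the mode labels are made up for illustration, and no validation beyond the basic shape is attempted. It is not how haliax parses these strings.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class DotSpec:
    mode: str                              # "positional", "contract", or "output_only"
    inputs: Tuple[Tuple[str, ...], ...]    # per-operand axis names (positional form only)
    contract: Tuple[str, ...]              # axes to contract, when they can be determined
    output: Tuple[str, ...]                # requested output axes, may include "..."


def parse_dot_spec(spec: str) -> DotSpec:
    """Classify a proposed hax.dot spec string (illustrative sketch only)."""
    if "->" not in spec:
        raise ValueError("spec must contain '->'")
    lhs, rhs = (part.strip() for part in spec.split("->", 1))
    output = tuple(rhs.split())

    if not lhs:
        # " -> batch channel embed": contract everything except the named dims.
        return DotSpec("output_only", (), (), output)

    if lhs.startswith("{") and lhs.endswith("}"):
        # "{h, w} -> ...": contract exactly the listed axes.
        contract = tuple(n.strip() for n in lhs[1:-1].split(",") if n.strip())
        return DotSpec("contract", (), contract, output)

    # "... c h w, h w d -> ... c d": normal einops-style positional form.
    inputs = tuple(tuple(op.split()) for op in lhs.split(","))
    contract = []
    for operand in inputs:
        for name in operand:
            if name != "..." and name not in output and name not in contract:
                contract.append(name)
    return DotSpec("positional", inputs, tuple(contract), output)


for s in ["... c h w, h w d -> ... c d", "{h, w} -> ", "{h, w} -> ... channel embed", " -> batch channel embed"]:
    print(repr(s), "=>", parse_dot_spec(s))
```

Keeping the parse result separate from the contraction itself would also make it easy to expose the same grammar through either `hax.dot` or a dedicated entry point.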

@reachtarunhere

Looks great to me. Pretty much what I am looking for (maybe except the last one haha)

We could also add a hax.einsum that resolves all of this and then calls hax.dot, instead of expanding hax.dot itself.

@dlwh dlwh mentioned this issue Feb 5, 2024
@dlwh
Member Author

dlwh commented Feb 5, 2024

@reachtarunhere any chance I could get you to look at https://github.com/stanford-crfm/haliax/pull/63/files#diff-b1aa00624eecf36f969b62aaee977cfac454841fa3dc40f480759e68bda5473bR57 and lmk what you think? Just asking you to glance at the docs, but if you want to look deeper that would be lovely too :-)

@reachtarunhere

@dlwh just back from vacation. Happy to take a look later today :)
