
make Linear support overlapping input/axis names? #53

Open · dlwh opened this issue Dec 20, 2023 · 5 comments
Labels: good first issue (Good for newcomers)

Comments

dlwh (Member) commented Dec 20, 2023

Currently, Haliax requires that all names in a single named array be unique. In general I think this is a good constraint. However, for Linear layers it's frequently a nuisance, since one often projects to something of the same shape, or wants to keep the same name ("hidden").

So, it might be a good idea to support overlapping names. This will complicate the implementation quite a bit but simplify some juggling outside. I think this is worth the complexity?

Probably we'd rename overlapping "output" names to ${name}_out and then rename them back to ${name} in the result. If we make this a contract, then you can use it to control sharding.
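
For concreteness, here is a minimal sketch of what that contract could look like. This is hypothetical code, not the current Haliax implementation; the exact hax.dot, alias, and rename signatures are assumed, and the axis names are illustrative:

import haliax as hax
from haliax import Axis

# Sketch: a projection whose input and output share the name "hidden".
# The weight internally uses "hidden_out" so its axis names stay unique,
# and the result is renamed back to "hidden" per the proposed contract.
Hidden = Axis("hidden", 64)
HiddenOut = Hidden.alias("hidden_out")

def apply(weight, x):
    y = hax.dot(x, weight, axis=Hidden)        # contract the shared "hidden" axis
    return y.rename({"hidden_out": "hidden"})  # restore the public name

Because the internal name is deterministic, a sharding mapping can target either the public "hidden" name or the internal "hidden_out" name predictably.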

cooljoseph1 (Contributor) commented:

I think it's easiest to just always rename all axes to ${name}_in for in axes and ${name}_out for out axes. (This guarantees there will be no conflicting names, since the in axes all end in "in" whereas the out axes all end in "out".) I've implemented that in the above pull request.
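
The mechanical core of that scheme is just building rename maps on both sides; roughly like this (a sketch, with a made-up helper name):

# Always-rename scheme: every in axis becomes ${name}_in and every out axis
# becomes ${name}_out, so the weight's axis names can never collide.
def disambiguate(in_axes, out_axes):
    ins = {ax.name: ax.name + "_in" for ax in in_axes}
    outs = {ax.name: ax.name + "_out" for ax in out_axes}
    return ins, outs

# disambiguate((Axis("hidden", 64),), (Axis("hidden", 64),))
# -> ({"hidden": "hidden_in"}, {"hidden": "hidden_out"})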

Is there a reason to not do this (e.g., performance issues)?

dlwh (Member, Author) commented Sep 4, 2024

It messes up FSDP, or at least it makes it so you have to specify that both embed_in and embed_out are sharded, which is a bit noisier.

cooljoseph1 (Contributor) commented:

I don't know how sharding works in Haliax. Would you mind explaining why it messes up sharding?

dlwh (Member, Author) commented Sep 5, 2024

Well, "messes up" is a bit strong, but the key idea behind sharding in Haliax is mapping named axes to a device mesh axis (cf. the tutorial: https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz). Currently, to set up FSDP, we do:

model = hax.shard(model, {"embed": "data"})

and this means that every "embed" axis in the model is sharded across the data axis of the device mesh. To add tensor parallelism, you'd do something like:

model = hax.shard(model, {"embed": "data", "mlp": "model"})

With your change, we'd have to do:

model = hax.shard(model, {"embed_in": "data", "embed_out": "data"})

which seems noisier. WDYT?

cooljoseph1 (Contributor) commented:

They seem pretty much equally noisy to me, and I think it's fine to make that change. In the first one you need to have separate names for all your axes in a sequence of linear layers, which can be just as confusing.

I think it ultimately comes down to needing a disjoint union of axis specs, not a union, and I don't think that's possible without renaming things.

Perhaps one could create some kind of tree (or DAG) of axes that are derived from other axes, and then, when sharding, automagically also shard any derived sub-axes, but that feels like overcomplicating things.
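
Concretely, the derived-axes idea would amount to expanding the user's mapping before sharding, something like this (purely illustrative; nothing like this exists in Haliax today):

# Map each public axis to the internal axes derived from it.
DERIVED = {"embed": ["embed_in", "embed_out"]}

def expand_mapping(mapping):
    expanded = dict(mapping)
    for parent, children in DERIVED.items():
        if parent in mapping:
            for child in children:
                # setdefault so an explicit entry for a derived axis wins
                expanded.setdefault(child, mapping[parent])
    return expanded

# hax.shard(model, expand_mapping({"embed": "data", "mlp": "model"}))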
