Maxvit #65
base: main
Conversation
Needs testing, but the main code should be done.
It works on random input. After testing the VQGAN on a WebDataset with the sbatch script, I'll test this too.
I added a custom TransformerLayer for MaxViT. Let me know if anyone has ideas on formulating this differently! My next step is mainly testing this out and comparing the VRAM usage. Then I'll open it for review.
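For reference, here's a minimal sketch of the MaxViT attention pattern (block attention within local windows, then grid attention across windows), loosely following lucidrains' implementation. This is not this PR's actual TransformerLayer: the class name, window size, use of `nn.MultiheadAttention`, and the layer layout are illustrative assumptions, and the real layer also keeps the original transformer feed-forward (hence the three FFNs discussed further down).

```python
# Illustrative sketch only, not the PR's TransformerLayer.
import torch
import torch.nn as nn
from einops import rearrange

def ffn(dim, mult=4):
    return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim * mult),
                         nn.GELU(), nn.Linear(dim * mult, dim))

class MaxVitLayer(nn.Module):
    """Block (local window) attention, then grid (sparse global) attention,
    each followed by a feed-forward, with pre-norm residual connections."""

    def __init__(self, dim, heads=8, window=8):
        super().__init__()
        self.window = window
        self.block_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.grid_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.block_norm, self.grid_norm = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.block_ffn, self.grid_ffn = ffn(dim), ffn(dim)

    def forward(self, x, h, w):
        # x: (batch, h * w, dim) -- flattened VQ token grid
        ws, hb, wb = self.window, h // self.window, w // self.window

        # block attention: each token attends within its own (ws x ws) window
        blk = rearrange(x, 'b (hb ws1 wb ws2) d -> (b hb wb) (ws1 ws2) d',
                        hb=hb, ws1=ws, wb=wb, ws2=ws)
        n = self.block_norm(blk)
        blk = blk + self.block_attn(n, n, n, need_weights=False)[0]
        blk = blk + self.block_ffn(blk)
        x = rearrange(blk, '(b hb wb) (ws1 ws2) d -> b (hb ws1 wb ws2) d',
                      hb=hb, ws1=ws, wb=wb, ws2=ws)

        # grid attention: tokens at the same position in every window attend to each other
        grd = rearrange(x, 'b (ws1 hb ws2 wb) d -> (b hb wb) (ws1 ws2) d',
                        ws1=ws, hb=hb, ws2=ws, wb=wb)
        n = self.grid_norm(grd)
        grd = grd + self.grid_attn(n, n, n, need_weights=False)[0]
        grd = grd + self.grid_ffn(grd)
        return rearrange(grd, '(b hb wb) (ws1 ws2) d -> b (ws1 hb ws2 wb) d',
                         ws1=ws, hb=hb, ws2=ws, wb=wb)

# e.g. a 32x32 token grid (f8 VQGAN on 256x256 images) with 8x8 windows
layer = MaxVitLayer(dim=512)
tokens = torch.randn(2, 32 * 32, 512)
out = layer(tokens, h=32, w=32)   # -> (2, 1024, 512)
```

The point of splitting attention this way is that no attention call ever operates on the full 1024-token sequence, only on 64-token groups.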
Fixed!
I can run the code without any shape errors now, but I'm noticing that the MaxViT layers OOM while the counterpart doesn't. I think I'm initializing some parameters to be too large, which I plan to check tomorrow.
OK! I found that the main memory issue was the feed-forward networks in each transformer layer. They have the most parameters within a transformer layer, and in MaxViT we need 3 of them instead of just 1, which makes the memory usage per layer roughly 3 times higher. I fixed it so the model is only about 2 times the size now. The checklist now is:
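To put rough numbers on the feed-forward point (the hidden size and expansion factor below are assumptions for illustration, not this PR's config), the FFNs dominate a layer's parameter count, so tripling them roughly triples the per-layer memory:

```python
# Back-of-the-envelope parameter counts; dim and mult are assumed values.
dim, mult = 1024, 4
ffn_params = 2 * dim * (mult * dim)      # two linears: dim -> 4*dim and 4*dim -> dim
attn_params = 4 * dim * dim              # q, k, v, out projections

standard_layer = attn_params + ffn_params            # one attention + one FFN
maxvit_layer = 3 * attn_params + 3 * ffn_params       # block + grid + original attention, each with its own FFN

print(f"FFN params: {ffn_params:,}")                  # ~8.4M
print(f"standard layer: {standard_layer:,}")          # ~12.6M
print(f"maxvit-style layer: {maxvit_layer:,}")        # ~37.7M, roughly 3x
```

The PR's fix brings the overall model down to roughly 2x; the exact change isn't reproduced here.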
Without MaxViT: [memory usage screenshot]
With MaxViT: [memory usage screenshot]

I think the main way Google resolved this higher VRAM usage is to use optimizers like Lion and Adafactor instead of AdamW, since AdamW keeps extra optimizer state proportional to the model weights.

With Lion: [memory usage screenshot]

So, relative to the memory taken up by the weights themselves, it comes out better with MaxViT.
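For context on the optimizer point (this is general PyTorch behavior, not code from this PR): AdamW stores two extra buffers per parameter (first and second moments), roughly 2x the parameter memory, while Lion keeps a single momentum buffer, roughly 1x, so the extra parameters from MaxViT matter less in total. A quick way to check the AdamW side:

```python
# General PyTorch behavior, not code from this PR: inspect AdamW's optimizer state.
import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.Linear(4096, 1024))
params = sum(p.numel() for p in model.parameters())

opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
model(torch.randn(2, 1024)).sum().backward()
opt.step()

# exp_avg and exp_avg_sq are kept per parameter, so state is ~2x the parameters
state = sum(t.numel() for s in opt.state.values()
            for t in s.values() if torch.is_tensor(t))
print(f"parameters: {params:,}, AdamW state: {state:,} (~{state / params:.1f}x)")
```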
@williamberman @patil-suraj @sayakpaul @pcuenca I think I'm pretty much done. Let me know if there are any experiments or code changes you'd recommend!

The TL;DR for this PR: this is the attention format Google used for the second stage of Muse to reduce VRAM usage at the higher sequence length that comes from using an f8 VQGAN instead of an f16 VQGAN. This PR is heavily inspired by lucidrains' MaxViT implementation here.
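To spell out why the sequence length goes up (the 256x256 resolution below is an assumed example; the point is the ratio, not the absolute numbers): an f8 VQGAN downsamples by 8 and an f16 by 16, so the f8 token grid has 4x the area, and full self-attention cost grows with the square of the token count.

```python
# Assumed example resolution; only the f8 vs f16 ratio matters here.
image = 256
for factor in (16, 8):
    side = image // factor            # token grid side length
    seq_len = side * side             # tokens per image
    full_attn = seq_len ** 2          # pairwise interactions for full self-attention
    print(f"f{factor}: {side}x{side} grid -> {seq_len} tokens, ~{full_attn:,} attention pairs")
# f16: 16x16 -> 256 tokens (~65,536 pairs); f8: 32x32 -> 1024 tokens (~1,048,576 pairs)
# MaxViT-style block/grid attention instead keeps every attention within 8x8 = 64-token groups.
```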
Draft PR where I'm adapting MaxViT from lucidrains' code. The corresponding issue is here.