Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WaveMix Multi-GPU CUDA fix #8

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

BhavyaKohli
Copy link

In the original wavemix/__init__.py, some layers are put on 'cuda:0' by default, and this cannot be changed. This behavior is undesirable since upon initialization, some weights are on 'cuda:0' and the rest are on 'cpu', until model.to('cuda:0') is called.

Additionally, on servers with more than one gpu, if gpu 0 is in use, wavemix models cannot be used. There is another bug which prevents conveniently putting the whole model on say 'cuda:4': the DWTForward layers are defined outside the LevelXWaveblocks and since they are not "children" of the waveblocks, when waveblock.to('cuda:4') is called, the xf1, xf2, xf3, xf4 are still on 'cuda:0', causing device mismatch errors.

This patch fixes these issues by changing the default device argument in the defined functions to 'cpu', which makes all weights initialized on 'cpu' at first, and defines xf1, xf2, xf3, xf4 inside their respective waveblocks, which allows waveblock.to('cuda:N') to work as intended.

@BhavyaKohli
Copy link
Author

@pranavphoenix Can you have a look

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant