Unable to use multiple devices #9
Comments
After further inspection, I think this issue can be expanded to simply request the flexibility of passing any arguments to the
Thanks for bringing this up! One issue I see is that the parallelization strategy then also becomes a factor to consider, so I agree that the best course of action is an option to pass kwargs to the trainer. I am not sure whether this could cause conflicts between the inference and training kwargs, though. I wasn't aware that the number of GPUs also plays a role in the lr calculation. Out of interest, do you have a reference for that? This also requires the number of devices to be set explicitly instead of using the default value of (By the way, in the
Hi, thanks for the quick response! Here is the reference for the learning rate calculation based on device count.
Makes sense, thank you. I think Regarding the I'll try to test this out soon, but don't feel like you have to leave this open. If I find success with the multi-GPU training, I can submit a PR with the changes if that works for you.
Cool, yeah, if you want to draft a pull request, I'd be happy to merge it. To be fair, I am not sure whether there are any subtle issues with multi-GPU training. The most important question is whether the batch size will vary, because that has quite an important influence. It should become apparent if the visualization of CIFAR-10 looks off after the default training. Other than that I am not aware of any issues, but I also do not have much experience with multiple GPUs.
I implemented the functionality for passing custom kwargs to the trainer. So far I haven't changed the lr calculation, as I am not really using more than one GPU at the moment. Currently a warning is issued when multiple devices are used but the lr is left at its default. I would be happy to merge code that accounts for how the lr changes when using multiple devices, though.
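As a rough illustration of the behaviour described in this comment, here is a minimal sketch of merging user-supplied trainer kwargs over defaults and warning when several devices are requested while the learning rate stays at its default. The names `DEFAULT_TRAINER_KWARGS`, `resolve_trainer_kwargs`, and the `lr="auto"` convention are hypothetical and chosen for this example only; they are not taken from the t-simcne code.

```python
import warnings

# Hypothetical defaults for illustration; the library's real defaults may differ.
DEFAULT_TRAINER_KWARGS = {"accelerator": "auto", "devices": 1}


def resolve_trainer_kwargs(trainer_kwargs=None, lr="auto"):
    """Merge user-supplied trainer kwargs over the defaults.

    Emits a warning when more than one device is requested while the
    learning rate is left at its default ("auto" in this sketch), because
    the default does not account for the device count.
    """
    merged = {**DEFAULT_TRAINER_KWARGS, **(trainer_kwargs or {})}
    devices = merged.get("devices", 1)
    # `devices` can also be "auto" or -1; this sketch only counts devices
    # when given a positive integer or an explicit list of device ids.
    n_devices = len(devices) if isinstance(devices, (list, tuple)) else devices
    if isinstance(n_devices, int) and n_devices > 1 and lr == "auto":
        warnings.warn(
            f"Using {n_devices} devices with the default learning rate; "
            "consider setting the learning rate explicitly.",
            UserWarning,
        )
    return merged
```

The merged dictionary could then be unpacked into the trainer constructor, so that users control `devices`, the strategy, and similar options without the library having to enumerate them.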
Great, thanks! I'm finding that the implementation is more complex than just configuring the devices and adjusting the learning rate, as I initially proposed. Here are a few of my findings so far:
First, it's crucial to gather the tensors across all devices before calculating the loss, so that the negatives are sampled from the full batch rather than from each device's share of it. Gradients also need to be synchronized across the devices in this setup.
Second, since the model uses batch normalization, the batch-norm statistics need to be synchronized as well. Activation scaling should be computed over the batches on all devices, not just within each device's own mini-batch.
Third, when scaling the batch size significantly (it sounds like sizes greater than 1024), a shift from a standard SGD optimizer to LARS becomes necessary.
I have implemented the first two points, but my testing was slowed down by a break over the holiday. Currently, I'm running a training session with a batch size of 128 across 4 GPUs (effectively 512) to replicate the results from the paper. The initial signs are promising: I observed some clustering in reduced epochs just the other day, so I'm cautiously optimistic. Assuming this training run is successful, the next step will be to implement and test with increased batch sizes and LARS. So there's still a fair bit of work ahead.
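To put numbers on the first point, here is a small arithmetic sketch. It assumes a SimCLR-style setup with two augmented views per sample, where every other view in the batch except the anchor's own positive counts as a negative; that counting rule and both function names are assumptions of this example, not something taken from the code under discussion.

```python
def effective_batch_size(per_device_batch_size, devices):
    """Number of samples contributing to one optimizer step across all devices."""
    return per_device_batch_size * devices


def negatives_per_anchor(batch_size, views=2):
    """Negatives available to one anchor view in a contrastive batch.

    The batch holds `views * batch_size` views in total; the anchor itself
    and its `views - 1` positives are excluded.
    """
    return views * batch_size - views


per_device, devices = 128, 4

# Loss computed separately on each device: only local negatives are seen.
local_negatives = negatives_per_anchor(per_device)  # 254

# Loss computed after gathering embeddings from all devices.
gathered_negatives = negatives_per_anchor(
    effective_batch_size(per_device, devices)
)  # 1022
```

Under these assumptions, gathering before the loss roughly quadruples the negatives each anchor sees on 4 devices, which is the motivation for the gather step.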
It appears the number of devices is hardcoded:
t-simcne/tsimcne/tsimcne.py, line 481 in e2988f5
t-simcne/tsimcne/tsimcne.py, line 517 in e2988f5
Given the contrastive learning task, it would be preferable to utilize more resources for training and allow the `devices` argument to be passed to your model classes. I'm unsure to what extent other aspects of the model would require changes, but I believe the learning rate calculation in `lr_from_batchsize` would need to be updated as well (batch_size * devices).
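The proposed change to the learning rate can be sketched as follows, assuming a linear scaling rule in which the rate grows in proportion to the effective batch size. The function name `scaled_lr` and the constants `base_lr=0.03` and `base_batch_size=256` are placeholders chosen for this example; the actual formula inside `lr_from_batchsize` may differ.

```python
def scaled_lr(batch_size, devices=1, base_lr=0.03, base_batch_size=256):
    """Linear-scaling sketch: lr grows with the effective batch size.

    `base_lr` and `base_batch_size` are placeholder constants for this
    example, not values confirmed from the library.
    """
    if batch_size <= 0 or devices <= 0:
        raise ValueError("batch_size and devices must be positive")
    effective_batch_size = batch_size * devices
    return base_lr * effective_batch_size / base_batch_size


# Single device: the per-device batch size is the effective batch size.
lr_single = scaled_lr(128)

# Four devices with the same per-device batch size: four times the rate.
lr_multi = scaled_lr(128, devices=4)
```

Keeping `devices` as an explicit parameter with a default of 1 leaves single-device behaviour unchanged, so existing callers would not be affected.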