
SigLip memory consumption increases as we scale number of GPUs #942

khalidsaifullaah opened this issue Sep 26, 2024 · 5 comments

@khalidsaifullaah

khalidsaifullaah commented Sep 26, 2024

From the SigLIP paper, my understanding is that the loss doesn't require any all_gather: it iteratively performs local b x b computations, where b is the micro_batch_size (see this section from the paper).

So if I can fit, say, micro_batch_size 10 on 8 GPUs, and I then increase the number of GPUs to 16, 32, 64, 128, ..., my memory consumption should (more or less) remain the same, just like with normal DDP. Put simply, we should (in theory) be able to scale world_batch_size, i.e. the number of nodes, while keeping micro_batch_size constant, right?
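To make that memory argument concrete, here's a minimal single-process sketch (my own toy simulation under assumed logit_scale/logit_bias values, not the open_clip or big_vision implementation) of the chunked sigmoid loss: each step only materializes one b x b logit block, so per-device peak memory shouldn't depend on how many chunks (i.e. how many devices) there are:

```python
import torch
import torch.nn.functional as F

def chunked_siglip_loss(img_emb, txt_chunks, logit_scale=10.0, logit_bias=-10.0):
    """Toy single-process stand-in for the chunked SigLIP loss.

    txt_chunks[0] plays the role of the local text batch (positives on the
    diagonal); the remaining chunks stand in for text features received
    from neighbours (all negatives). Only one b x b logit block exists at
    a time, so peak memory is independent of len(txt_chunks) ("world size").
    """
    b = img_emb.shape[0]
    loss = img_emb.new_zeros(())
    for step, txt in enumerate(txt_chunks):
        logits = img_emb @ txt.t() * logit_scale + logit_bias  # b x b block
        if step == 0:
            labels = 2.0 * torch.eye(b) - 1.0   # +1 on diagonal, -1 elsewhere
        else:
            labels = -torch.ones(b, b)          # neighbour chunks: all negatives
        # -log sigmoid(z * logits) == softplus(-z * logits)
        loss = loss + F.softplus(-labels * logits).sum()
    return loss / b
```

Computed chunk by chunk like this, the result should match the full b x (world * b) sigmoid loss exactly; only the peak memory differs.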

But what I've observed is that memory consumption spikes as I increase world_batch_size (number of nodes), and I need to lower my micro_batch_size (even to as low as 2 on 128 devices).

  1. I'm wondering if my understanding of SigLIP is correct, i.e. that keeping micro_batch_size constant allows you to scale world_batch_size? It could also be that they use some sort of TPU trick (I don't have much insight into that).
  2. I have only skimmed the SigLIP implementation here, and I think it could also be that memory isn't freed while swapping neighbors, which is why consumption accumulates...?

I could be totally wrong on both of these, so I'd be glad to hear from anyone who has tried scaling world_batch_size and seen similar results, so I can validate my hypothesis.

@rwightman
Collaborator

rwightman commented Oct 3, 2024

@khalidsaifullaah yeah, it's not working quite as efficiently as it should. My current isend/irecv impl, while reasonable in theory, appears not to be a well-optimized approach.
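For reference, the send-right/recv-left ring pattern that a neighbour exchange is built around can be simulated in a single process (a hypothetical sketch, not the actual open_clip isend/irecv code). The point of the simulation is that only the current and the incoming buffer need to be alive at any step; an implementation that holds references to every received buffer would instead grow linearly with world size:

```python
import torch

def simulate_neighbour_exchange(txt_per_rank):
    """Single-process simulation of a send-right/recv-left ring exchange.

    Returns, for each rank, the sequence of text buffers it would see over
    world - 1 steps. At every step, only two buffers per rank are needed
    (current + incoming); the previous buffer can be freed immediately.
    """
    world = len(txt_per_rank)
    seen = [[t] for t in txt_per_rank]       # step 0: each rank's local batch
    current = list(txt_per_rank)
    for _ in range(world - 1):
        # every rank "sends" its current buffer to its right neighbour
        incoming = [current[(r - 1) % world] for r in range(world)]
        current = incoming                   # old buffers become garbage here
        for r in range(world):
            seen[r].append(current[r])
    return seen
```

After world - 1 steps every rank has seen every other rank's chunk exactly once, which is the invariant the real P2P implementation has to preserve while recycling buffers.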

Looking at the big_vision codebase, where the SigLIP authors have the original version of the models, there's no chunked sigmoid impl in the current code, but there is one in a deprecated file. Interestingly, a comment there states that the 'ppermute' version (which should be equivalent to send/recv neighbour exchange, and should itself be more optimal, being one high-level op instead of several) used more memory than playing 'hot potato' with all_reduce instead. Hmmm https://github.com/google-research/big_vision/blob/46b2456f54b9d4f829d1925b78943372b376153d/big_vision/trainers/proj/image_text/_deprecated_contrastive.py#L168-L200

I was thinking of trying an all_reduce impl... all_reduce and all_gather should be among the most optimized collective ops across the software stack and network, as they are the most heavily used.
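To illustrate what an all_reduce-based 'hot potato' shift could look like (a hypothetical sketch under my own assumptions, not the big_vision or open_clip code): each rank writes its chunk into its right neighbour's slot of a zeroed world-sized buffer, and a summed all_reduce then delivers the rotated chunks to every rank. Here the collective is stubbed out with a plain sum so the logic runs in one process:

```python
import torch

def hot_potato_shift(rank_tensors):
    """Simulated one-step ring shift built from an all_reduce(sum).

    Each rank scatters its b x d chunk into its right neighbour's slot of
    a zeroed (world, b, d) buffer; summing the buffers (the stand-in for
    dist.all_reduce with SUM) leaves rank r's left-neighbour chunk in
    slot r. The price is a world-sized buffer per step, traded for using
    the most heavily optimized collective op.
    """
    world = len(rank_tensors)
    b, d = rank_tensors[0].shape
    buffers = []
    for rank, t in enumerate(rank_tensors):
        buf = torch.zeros(world, b, d)
        buf[(rank + 1) % world] = t        # place my chunk in neighbour's slot
        buffers.append(buf)
    reduced = torch.stack(buffers).sum(0)  # stand-in for dist.all_reduce(SUM)
    return [reduced[r] for r in range(world)]
```

The world-sized staging buffer is exactly the memory/throughput trade-off the big_vision comment alludes to: more bytes moved per step than a point-to-point exchange, but through a far better tuned code path.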

@rwightman
Collaborator

rwightman commented Oct 26, 2024

@khalidsaifullaah I'm experimenting with different impls of the loss in #971 to see if any scale better... feel free to try them; feedback would be welcome.

@khalidsaifullaah
Author

Oh awesome, I had moved on to implementing a different distributed loss. However, in a quick test of the new commit, I still seem to get OOM when scaling horizontally to 128 GPUs (mbsz=2, impl="reduce"). I'll run more tests with the other impl settings.

@long8v

long8v commented Oct 29, 2024

Awesome! From my observations, training with the SigLIP loss on 100+ GPUs was considerably slower than with the CLIP loss. It would be really helpful if you could also report the 'train/batch_time' metric for each implementation type @khalidsaifullaah @rwightman

@rwightman
Collaborator

@khalidsaifullaah @long8v FWIW, I wouldn't necessarily take "no extra overhead as world size increases" as the passing criterion; with gradient buffers, allocator behaviour, etc., there's still likely to be some impact from world size. It should be more efficient than the CLIP loss, though.
