From the SigLIP paper, my understanding is that it doesn't require any all_gather and always performs local b x b computations iteratively, where b is the micro_batch_size (see this section of the paper).
So if I can fit, say, a micro_batch_size of 10 on 8 GPUs and then increase the number of GPUs to 16, 32, 64, 128, ..., my memory consumption should (more or less) stay the same, just like with normal DDP. Put simply, we should (in theory) be able to scale world_batch_size, or the number of nodes, while keeping micro_batch_size constant, right?
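To make that expectation concrete, here is a minimal sketch of the chunked sigmoid loss as I understand it from the paper, assuming PyTorch DDP with one process per device. This is my own illustration, not the official (JAX/TPU) implementation or the code I'm running; names like `chunked_sigmoid_loss` and the ring exchange via `batch_isend_irecv` are made up for the example:

```python
import torch
import torch.distributed as dist


def sigmoid_loss_chunk(img_emb, txt_emb, logit_scale, logit_bias, positive):
    # b x b logits for one (image chunk, text chunk) pair -- this is the only
    # quadratic buffer that should ever live on a device at once.
    logits = logit_scale * img_emb @ txt_emb.T + logit_bias
    labels = -torch.ones_like(logits)
    if positive:
        labels.fill_diagonal_(1.0)  # positives only exist on the local diagonal
    # normalization by the global batch size omitted for brevity
    return -torch.nn.functional.logsigmoid(labels * logits).sum()


def chunked_sigmoid_loss(img_emb, txt_emb, logit_scale, logit_bias):
    rank, world = dist.get_rank(), dist.get_world_size()
    # local positive chunk first
    loss = sigmoid_loss_chunk(img_emb, txt_emb, logit_scale, logit_bias, positive=True)
    txt_chunk = txt_emb
    for _ in range(world - 1):
        # ring exchange: send the current text chunk to the next rank and
        # receive one from the previous rank; only one remote chunk should be
        # resident per device at any time.
        recv = torch.empty_like(txt_chunk)
        send_op = dist.P2POp(dist.isend, txt_chunk.contiguous(), (rank + 1) % world)
        recv_op = dist.P2POp(dist.irecv, recv, (rank - 1) % world)
        for req in dist.batch_isend_irecv([send_op, recv_op]):
            req.wait()
        txt_chunk = recv  # the previous chunk should now be freeable
        loss = loss + sigmoid_loss_chunk(
            img_emb, txt_chunk, logit_scale, logit_bias, positive=False
        )
    return loss
```

If this matches the intended algorithm, per-device memory for the loss is O(b x b) plus one in-flight text chunk, independent of world_size.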
But what I've observed is that memory consumption spikes as I increase world_batch_size (the number of nodes), and I need to lower my micro_batch_size (even as low as 2 for 128 devices).
So I'm wondering whether my understanding of SigLIP is correct: does keeping micro_batch_size constant allow you to scale world_batch_size, or the number of devices?
It could also be that the SigLIP GPU implementation I'm using doesn't work quite like the official TPU one. For example, it might not free memory while swapping neighbors, which would explain why consumption accumulates.
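One way to check that hypothesis would be to log the CUDA memory counters around the chunked loss, something like the rough sketch below (log_mem is just a helper I'm making up for illustration, and chunked_sigmoid_loss refers to the sketch above, not the repo's code):

```python
import torch


def log_mem(tag: str) -> None:
    # standard torch.cuda counters, reported in MiB
    alloc = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"{tag}: allocated={alloc:.0f} MiB, peak={peak:.0f} MiB")


# inside one training step:
# torch.cuda.reset_peak_memory_stats()
# loss = chunked_sigmoid_loss(img_emb, txt_emb, logit_scale, logit_bias)
# log_mem("after chunked loss")
```

If the peak grows roughly linearly with world_size at a fixed micro_batch_size, that would suggest the received chunks are not being released after their local b x b computation.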
I could be totally wrong on both of these assumptions, so I'd be glad if the authors could provide some insight so I can validate my hypothesis about horizontal scaling of SigLIP.