Parallelize parameter weight computation using PyTorch Distributed #22
Conversation
slurm option is available if script is started with srun/sbatch
create_parameter_weights.py
Outdated
```python
dist.destroy_process_group()
...
def adjust_dataset_size(ds, world_size, batch_size):
```
Is my understanding correct that the dataset size is adjusted by ignoring the last samples? While that does not change the statistics much for large datasets, I think it needs to be clearly communicated when this script is used. It also means that the statistics will change somewhat depending on how many ranks you run this on, right?
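For illustration, the tail-dropping behaviour discussed here presumably amounts to something like the sketch below; this is an assumption about the implementation, not the PR's actual code, and the example numbers are made up.

```python
# Hypothetical sketch of a drop-the-tail adjustment (an assumption, not the
# PR's actual code): keep only as many samples as fit into full batches on all ranks.
from torch.utils.data import Subset

def adjust_dataset_size(ds, world_size, batch_size):
    # Largest multiple of (world_size * batch_size) that fits in the dataset
    usable = (len(ds) // (world_size * batch_size)) * (world_size * batch_size)
    return Subset(ds, range(usable))

# Example: 1000 samples with batch_size=32 keep 896 samples on 4 ranks
# but 960 samples on 2 ranks, so the statistics depend on the rank count.
```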
True, I will present another option that randomly draws samples from the dataset until all ranks are fully filled. At this point you might notice that I am not a computer engineer. 😇 I will in any case add a comment about this and give the user a choice between speed and accuracy.
Does that really solve the problem though? Then you would get different statistics depending on the random sampling.
What I am thinking more is that it would be nice if we guarantee that the true mean and std.-dev. of the training data is what we return here, no matter the parallelization choices. One could for example distribute out as many samples as possible to fill full batches on all ranks, run those batches in parallel, and then just run any samples left over in batches only on rank 0 at the end. Then do the little bit of math to combine this into the correct mean.
Alternatively, one could pad the dataset with duplicate samples to fill it up for all ranks, and then keep track of the padding samples to not count when reducing to the mean. Seems a bit trickier/messier to implement.
In the end I am not sure if it is worth it to put in the effort. Maybe it's enough to just document that this will drop some samples in the end of the dataset for the statistics computations.
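For reference, here is a rough sketch of the "little bit of math" to combine per-rank statistics exactly, independent of how the samples are split. It uses `all_reduce` on sufficient statistics rather than the gather-based approach in this PR; the function name is made up, and it assumes an already-initialized process group.

```python
# Rough sketch (not the PR's implementation): exact mean/std from per-rank
# sufficient statistics, assuming an initialized process group.
import torch
import torch.distributed as dist

def exact_mean_std(local_values):
    """local_values: tensor of shape (n_local, n_features) on this rank."""
    n = torch.tensor([local_values.shape[0]], dtype=torch.float64)
    s = local_values.double().sum(dim=0)
    sq = (local_values.double() ** 2).sum(dim=0)

    # Sum the counts, sums, and squared sums over all ranks
    for t in (n, s, sq):
        dist.all_reduce(t, op=dist.ReduceOp.SUM)

    mean = s / n
    var = torch.clamp(sq / n - mean**2, min=0.0)  # E[x^2] - E[x]^2
    return mean, var.sqrt()
```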
Have you done any testing of how long this takes on CPU vs GPU? I am curious if the gains from GPU acceleration make up for the time needed to shuffle the data over to it for these computations.
I have done tests with my 7 TB (436524 samples) COSMO training data (zarr-based). As this solution scaled rather well, I could reduce the runtime of the create_parameter_weights.py script substantially.
I was thinking that once the data is in a zarr, won't this script be replaced by just a few xarray operations?
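For context, the xarray route hinted at here could look roughly like the sketch below; the store path, variable name, and dimension names are purely illustrative assumptions.

```python
# Rough sketch of the xarray/zarr route (all names are assumptions): lazy,
# dask-backed statistics over the whole training archive.
import xarray as xr

ds = xr.open_zarr("cosmo_training_data.zarr")   # hypothetical store
state = ds["state"]                             # hypothetical variable name
mean = state.mean(dim=("time", "grid"))         # hypothetical dimension names
std = state.std(dim=("time", "grid"))
mean, std = mean.compute(), std.compute()
```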
TL;DR

I suggest merging this into main for now, for the 10x speedup on CPU with GLOO and the newly introduced option. So the smart thing would have been to wait for xarray, but instead I decided to make a little pseudo-scientific study 😮‍💨

1. Benchmarks

The script was evaluated in three different modes (local, Slurm-GPU, Slurm-CPU). Note that multi-node is currently not supported (yes, I had to stop myself at some point).
For the benchmarks I arbitrarily iterated 500 times over the dataset.

THE RESULTS:

Local: 55:23 min
Slurm-GPU: 5:56 min
Slurm-CPU: 5:38 min

So as Joel already suspected, it was not worth it to port these simple calculations to GPU. We can also remove the GPU code path.

2. Robustness

The script must produce the same statistics as the old script from main.
Notes

My colleagues (generation TikTok) did actually complain that they cannot start training the model for 50 h while waiting for the stats, so having a faster script is certainly nice. If you actually want to use the …
Okay, I hope I will find pieces of this useful in other PRs. Now please someone stop me from wasting more time here 🤣
You went down quite the rabbit hole here. I'll try to stop you from putting more time into this 😛
While much of this code might be replaced by xarray-based solutions in the future, I still think that your investigation here will be useful when looking over how to do that. So not just a waste of time 😄 And in particular it means that we don't need to prioritize a GPU-compatible implementation for this.
With that said, I agree that we should just merge this now and revisit this script later with xarray. I tested the script and it seems to be working great. Add this to the changelog and then go ahead and merge!
Changelog updated with missing entry for #22
Description
This PR introduces parallelization to the `create_parameter_weights.py` script using PyTorch Distributed. The main changes include:

- Added functions `get_rank()`, `get_world_size()`, `setup()`, and `cleanup()` to initialize and manage the distributed process group.
  - `get_rank()` retrieves the rank of the current process in the distributed group.
  - `get_world_size()` retrieves the total number of processes in the distributed group.
  - `setup()` initializes the distributed process group using the NCCL (for GPU) or gloo (for CPU) backend.
  - `cleanup()` destroys the distributed process group.
- Modified the `main()` function to take `rank` and `world_size` as arguments and set up the distributed environment.
- The dataset size is adjusted to be divisible by `(world_size * batch_size)` using the `adjust_dataset_size()` function.
- A `DistributedSampler` is used to partition the dataset among the processes.
- Parallelized the computation of means and squared values across the dataset; the per-rank results are gathered with `dist.all_gather_object()`.
- Parallelized the computation of one-step difference means and squared values; the per-rank results are gathered with `dist.all_gather_object()`.

These changes enable the script to leverage multiple processes/GPUs to speed up the computation of parameter weights, means, and standard deviations. The dataset is partitioned among the processes, and the results are gathered and aggregated by the root process.
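As a rough illustration of the wiring described above (a sketch under assumptions, not the exact code in this PR): it assumes the launcher sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, and that the dataset yields plain tensors.

```python
# Minimal sketch of the distributed wiring described above; illustrative only.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    # NCCL for GPU runs, gloo for CPU-only runs
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def main(rank, world_size, ds, batch_size=32):
    setup(rank, world_size)
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(ds, batch_size=batch_size, sampler=sampler)

    # Per-rank partial statistics (here: per-batch means as a stand-in)
    local_means = [batch.double().mean() for batch in loader]

    # Gather every rank's list of partial results
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_means)

    if rank == 0:
        all_means = [m for rank_means in gathered for m in rank_means]
        print(torch.stack(all_means).mean())
    cleanup()
```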
To run the script in a distributed manner, it can be launched using Slurm.
Please review the changes and provide any feedback or suggestions.