Parallelize parameter weight computation using PyTorch Distributed #22

Merged (11 commits into mllam:main, Jun 3, 2024)

Conversation

@sadamov (Collaborator) commented May 2, 2024

Description

This PR introduces parallelization to the create_parameter_weights.py script using PyTorch Distributed. The main changes include:

  1. Added functions get_rank(), get_world_size(), setup(), and cleanup() to initialize and manage the distributed process group (a minimal sketch of this pattern is shown after the list below).

    • get_rank() retrieves the rank of the current process in the distributed group.
    • get_world_size() retrieves the total number of processes in the distributed group.
    • setup() initializes the distributed process group using the NCCL (GPU) or Gloo (CPU) backend.
    • cleanup() destroys the distributed process group.
  2. Modified the main() function to take rank and world_size as arguments and set up the distributed environment.

    • The device is set based on the rank and available GPUs.
    • The dataset is adjusted to ensure its size is divisible by (world_size * batch_size) using the adjust_dataset_size() function.
    • A DistributedSampler is used to partition the dataset among the processes.
  3. Parallelized the computation of means and squared values across the dataset.

    • Each process computes the means and squared values for its assigned portion of the dataset.
    • The results are gathered from all processes using dist.all_gather_object().
    • The root process (rank 0) computes the final mean, standard deviation, and flux statistics using the gathered results.
  4. Parallelized the computation of one-step difference means and squared values.

    • Similar to step 3, each process computes the difference means and squared values for its assigned portion of the dataset.
    • The results are gathered from all processes using dist.all_gather_object().
    • The final difference mean and standard deviation are computed using the gathered results.
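For illustration, here is a minimal sketch of the pattern described above. The helper names (get_rank, get_world_size, setup, cleanup) come from the PR; the dataset shape, reduction dimensions, and the main() signature details are placeholder assumptions, not the actual script. The same gather-then-reduce-on-rank-0 pattern covers steps 3 and 4.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()


def setup(rank, world_size):
    # NCCL for GPU runs, Gloo for CPU-only runs
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def main(rank, world_size, ds, batch_size):
    setup(rank, world_size)
    # Assuming one GPU per rank when CUDA is available
    device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")

    # Partition the (size-adjusted) dataset among the processes
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(ds, batch_size=batch_size, sampler=sampler)

    # Each rank accumulates per-variable means and mean squares over its shard;
    # assumes batches of shape (batch, grid, variables) and equal-sized batches
    means, squares = [], []
    for batch in loader:
        batch = batch.to(device)
        means.append(torch.mean(batch, dim=(0, 1)).cpu())
        squares.append(torch.mean(batch ** 2, dim=(0, 1)).cpu())

    # Gather every rank's partial results on all processes
    gathered_means = [None] * world_size
    gathered_squares = [None] * world_size
    dist.all_gather_object(gathered_means, means)
    dist.all_gather_object(gathered_squares, squares)

    if rank == 0:
        mean = torch.mean(torch.cat([torch.stack(m) for m in gathered_means]), dim=0)
        second_moment = torch.mean(torch.cat([torch.stack(s) for s in gathered_squares]), dim=0)
        std = torch.sqrt(second_moment - mean ** 2)
        # ... save mean/std; the same pattern repeats for one-step differences and flux stats

    cleanup()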

These changes enable the script to leverage multiple processes/GPUs to speed up the computation of parameter weights, means, and standard deviations. The dataset is partitioned among the processes, and the results are gathered and aggregated by the root process.

To run the script in a distributed manner, it can be launched using Slurm.

Please review the changes and provide any feedback or suggestions.

Simon Adamov added 3 commits May 1, 2024 21:20
@sadamov sadamov added the enhancement New feature or request label May 2, 2024
@sadamov sadamov requested a review from joeloskarsson May 2, 2024 05:16
@sadamov sadamov self-assigned this May 2, 2024
requirements.txt: review comment (outdated, resolved)
Review thread on create_parameter_weights.py, at the adjust_dataset_size() helper:

dist.destroy_process_group()


def adjust_dataset_size(ds, world_size, batch_size):

@joeloskarsson (Collaborator) commented:

Is my understanding correct that the way the dataset size is adjusted is by ignoring the last samples? While that does not change the statistics much for large datasets, I think it needs to be clearly communicated when this script is used. It also means that the statistics will change somewhat depending on how many ranks you run this on, right?

@sadamov (Collaborator, Author) replied May 30, 2024:

True, I will present another option that randomly draws samples from the dataset until all ranks are fully filled. At this point you might notice that I am not a computer engineer. 😇 In any case, I will add a comment about this and give the user a choice between speed and accuracy.

@joeloskarsson (Collaborator) replied:

Does that really solve the problem, though? Then you would get different statistics depending on the random sampling.

What I am thinking of more is that it would be nice if we could guarantee that the true mean and std.-dev. of the training data is what we return here, no matter the parallelization choices. One could, for example, distribute out as many samples as possible to fill full batches on all ranks, run those batches in parallel, and then just run any samples left over in batches only on rank 0 at the end. Then do the little bit of math to combine this into the correct mean.

Alternatively, one could pad the dataset with duplicate samples to fill it up for all ranks, and then keep track of the padding samples so they are not counted when reducing to the mean. Seems a bit trickier/messier to implement.

In the end I am not sure if it is worth the effort. Maybe it's enough to just document that this will drop some samples at the end of the dataset for the statistics computations.
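For reference, the "little bit of math" for combining per-rank results exactly is just a sample-count-weighted reduction. A minimal sketch (illustrative only, not code from this PR), assuming each rank reports how many samples it actually processed with padding excluded:

import torch


def combine_stats(counts, means, mean_sqs):
    # counts:   samples processed by each rank (padding excluded)
    # means:    per-rank mean per variable
    # mean_sqs: per-rank mean of squared values per variable
    n = torch.tensor(counts, dtype=torch.float64)
    total = n.sum()
    # Weight each rank's contribution by the number of samples it actually saw
    mean = sum(ni * mi for ni, mi in zip(n, means)) / total
    second_moment = sum(ni * si for ni, si in zip(n, mean_sqs)) / total
    std = torch.sqrt(second_moment - mean ** 2)
    return mean, std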

@joeloskarsson (Collaborator) commented:

Have you done any testing of how long this takes on CPU vs GPU? I am curious whether the gains from GPU acceleration make up for the time needed to shuffle the data over to it for these computations.

@sadamov (Collaborator, Author) replied May 30, 2024:

I have done tests with my 7 TB (436,524 samples) COSMO training data (zarr-based). As this solution scaled rather well, I could reduce the runtime of the create_parameter_weights.py script from roughly 50 h to 1 h. I don't remember whether CPU vs. GPU was relevant; I should do a proper assessment.

@joeloskarsson (Collaborator) commented:

I was thinking that once the data is in a zarr, won't this script be replaced by just a few xarray .mean calls, which are already parallelized? So it seems good to know whether we should put in the effort to do this ourselves on GPU, or just rely on xarray for that later.
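For context, the xarray route mentioned here would look roughly like the following (a sketch that assumes the training data lives in a zarr store with "time" and "grid" dimensions and a "state" variable; those names are placeholders):

import xarray as xr

ds = xr.open_zarr("path/to/training_data.zarr")           # lazy, dask-backed
mean = ds["state"].mean(dim=["time", "grid"]).compute()   # reductions run in parallel via dask
std = ds["state"].std(dim=["time", "grid"]).compute()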

@sadamov (Collaborator, Author) commented May 31, 2024:

TLDR

I suggest merging this into main for now, for the 10x speedup on CPU with Gloo and the newly introduced --distributed flag. GPU has no real benefit, so we can remove that if preferred. The reviewer's suggestions were implemented; feel free to double-check. This should be replaced with an xarray-based solution in the future.

So the smart thing would have been to wait for xarray, but instead I decided to make a little pseudo-scientific study 😮‍💨
The script now works on CPU/GPU, with Slurm or locally. No data is lost any longer, thanks to an improved Dataset class that introduces padding (I had the same idea as you, @joeloskarsson). In the following I want to talk about performance (1. Benchmarks) and robustness (2. Robustness).

1 Benchmarks

The script was evaluated in three different modes. Note that multi-node is currently not supported (yes, I had to stop myself at some point):

  • Locally, similar to how the script is currently run in main
  • With Slurm, using 1 GPU node with 4 A100 GPUs and 128 CPU cores, with NCCL
  • With Slurm, using 1 CPU node with 256 CPU cores, with Gloo

For the benchmarks I arbitrarily iterated 500 times over the meps_example, resulting in a dataset of 2000 samples.
I tracked the exact time required to execute the full script, either in Slurm via sacct or in the local terminal.

THE RESULTS: Local: 55:23 min | Slurm-GPU: 5:56 min | Slurm-CPU: 5:38 min

So, as Joel already suspected, it was not worth it to port these simple calculations to GPU. We can also remove the NCCL option from the script if preferred.

2 Robustness

The script must produce the same statistics as the old script from main, both in distributed mode and in single-task mode. To verify this, all stats were produced with the current script from main (called old), the new script in single-task mode (no prefix), and the new script in distributed mode (called distributed). In the following stdout from the terminal we can see that all stats match with a tolerance of 1e-5 between the new single-task and distributed runs on CPU. That being said, the GPU/NCCL runs sometimes only reach atol=1e-2 accuracy; that could probably be adjusted in the floating-point operation settings, but it is not really worth it for now.

print("diff_std == diff_distributed_std:", torch.allclose(diff_std, diff_distributed_std, atol=1e-5))
...
------------------------------------------------------------------------------
diff_std == diff_distributed_std: True
diff_mean == diff_distributed_mean: True
flux_stats == flux_distributed_stats: True
parameter_mean == parameter_distributed_mean: True
parameter_std == parameter_distributed_std: True
------------------------------------------------------------------------------
diff_std == diff_old_std: True
diff_mean == diff_old_mean: True
flux_stats == flux_old_stats: True
parameter_mean == parameter_old_mean: True
parameter_std == parameter_old_std: True
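The check itself boils down to a handful of torch.load / torch.allclose calls over the saved statistics; a rough sketch along these lines (the file names and directory layout here are assumptions, not taken from the PR):

import torch

static_dir = "data/meps_example/static"
for name in ("parameter_mean", "parameter_std", "diff_mean", "diff_std", "flux_stats"):
    new = torch.load(f"{static_dir}/{name}.pt")
    new_distributed = torch.load(f"{static_dir}_distributed/{name}.pt")
    old = torch.load(f"{static_dir}_old/{name}.pt")
    print(f"{name} == {name}_distributed:", torch.allclose(new, new_distributed, atol=1e-5))
    print(f"{name} == {name}_old:", torch.allclose(new, old, atol=1e-5))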

Notes

My colleagues (generation TikTok) did actually complain that they cannot wait 50 h for the stats before starting to train the model, so having a faster script is certainly nice. If you actually want to use the --distributed feature, you will need a scheduler. I am personally using Slurm as follows (I will add a note about this to the README.md if we merge):

#!/bin/bash -l
#SBATCH --job-name=NeurWP
#SBATCH --account=s83
#SBATCH --time=02:00:00
#SBATCH --nodes=1
#SBATCH --ntasks=16
#SBATCH --partition=postproc
#SBATCH --mem=444G
#SBATCH --no-requeue
#SBATCH --exclusive
#SBATCH --output=lightning_logs/neurwp_param_out.log
#SBATCH --error=lightning_logs/neurwp_param_err.log

# Load necessary modules
conda activate neural-lam

srun -ul python create_parameter_weights.py --batch_size 16 --distributed
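With srun launching one Python process per Slurm task, the script has to discover its rank and world size somehow. One common pattern (an assumption about the setup, not necessarily what this PR implements) is to read the Slurm environment variables:

import os

# Fall back to single-process defaults when not launched via srun
rank = int(os.environ.get("SLURM_PROCID", 0))
world_size = int(os.environ.get("SLURM_NTASKS", 1))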

Okay I hope I will find pieces of this useful in other PRs, now please someone stop me from wasting more time here 🤣

@joeloskarsson (Collaborator) left a review comment:

You went down quite the rabbit hole here. I'll try to stop you from putting more time on this 😛

While much of this code might be replaced by xarray-based solutions in the future, I still think that your investigation here will be useful when looking over how to do that. So not just a waste of time 😄 And in particular it means that we don't need to prioritize a GPU-compatible implementation for this.

With that said, I agree that we should just merge this now and revisit this script later with xarray. I tested the script and it seems to be working great. Add this to the changelog and then go ahead and merge!

create_parameter_weights.py: review comment (outdated, resolved)
@sadamov merged commit 743c07a into mllam:main on Jun 3, 2024
4 checks passed
@sadamov deleted the feature_parallel_stats_calc branch June 3, 2024 16:46
sadamov added a commit that referenced this pull request Jun 6, 2024
joeloskarsson pushed a commit that referenced this pull request Jun 13, 2024
Changelog updated with missing entry for #22