From cdaf082913692738735c34ec0c8f81acc18e674b Mon Sep 17 00:00:00 2001 From: japols Date: Tue, 24 Sep 2024 12:51:24 +0000 Subject: [PATCH 01/13] feat: initial implementation of dataloader memory optimization --- src/anemoi/training/data/dataset.py | 6 +++++- src/anemoi/training/train/forecaster.py | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index caaa986b..d0a2f4bb 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -222,7 +222,11 @@ def __iter__(self) -> torch.Tensor: x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") self.ensemble_dim = 1 - yield torch.from_numpy(x) + if self.model_comm_group_rank == 0: + yield torch.from_numpy(x) + else: + # yield dummy data only with shape information for non-root ranks + yield torch.tensor(x.shape, dtype=torch.long) def __repr__(self) -> str: return f""" diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index ff1acfd7..51ba2d7c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -219,9 +219,19 @@ def _step( validation_mode: bool = False, ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx + + # preprocess batch and broadcast from gpu0 to model comm group + if self.model_comm_group_rank == 0: # note that this defaults to 0 if model_comm_group is None + # for validation not normalized in-place because remappers cannot be applied in-place + batch = self.model.pre_processors(batch, in_place=not validation_mode) + else: + shape = (batch.shape[0],) + tuple(batch[0].tolist()) + batch = torch.zeros(shape, device=self.device) + + if self.model_comm_group is not None: + torch.distributed.broadcast(batch, src=0, group=self.model_comm_group) + loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) - # for validation not normalized in-place because remappers cannot be applied in-place - batch = self.model.pre_processors(batch, in_place=not validation_mode) metrics = {} # start rollout of preprocessed batch From fcc7c933b859bbb2f2806f0b69f87ca5428c9f4a Mon Sep 17 00:00:00 2001 From: japols Date: Wed, 2 Oct 2024 09:14:41 +0000 Subject: [PATCH 02/13] fix: non-reader tasks actually return before reading --- src/anemoi/training/data/dataset.py | 61 ++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index d0a2f4bb..3ed31654 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -12,6 +12,9 @@ from functools import cached_property from typing import Callable +import psutil +from collections import defaultdict + import numpy as np import torch from einops import rearrange @@ -191,6 +194,17 @@ def __iter__(self) -> torch.Tensor: Currently it receives data with an ensemble dimension, which is discarded for now. (Until the code is "ensemble native".) 
""" + if self.model_comm_group_rank != 0: + # yield dummy data only with shape information for non-root ranks + shape = (self.rollout + self.multi_step, self.data.shape[2], self.data.shape[3], self.data.shape[1]) + LOGGER.debug(f"Rank {self.model_comm_group_rank} using dummy shape {shape}") + + for _ in self.chunk_index_range: + yield torch.tensor(shape, dtype=torch.long) + + self._log_memory_usage() + return + if self.shuffle: shuffled_chunk_indices = self.rng.choice( self.chunk_index_range, @@ -222,11 +236,48 @@ def __iter__(self) -> torch.Tensor: x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") self.ensemble_dim = 1 - if self.model_comm_group_rank == 0: - yield torch.from_numpy(x) - else: - # yield dummy data only with shape information for non-root ranks - yield torch.tensor(x.shape, dtype=torch.long) + yield torch.from_numpy(x) + LOGGER.debug(f"Worker {self.worker_id} yielded data with shape {x.shape}") + + self._log_memory_usage() + + def _log_memory_usage(self) -> None: + """Log detailed memory usage, including RSS, PSS, USS, and shared memory.""" + LOGGER.debug(f"Worker {self.worker_id} logging memory usage.") + process = psutil.Process(os.getpid()) + mem_info = self._get_mem_info(process.pid) + + # Log the detailed memory usage + LOGGER.debug( + "Worker %d (pid %d, global_rank %d, model comm group %d) memory usage (in MB): " + "RSS: %.2f MB, PSS: %.2f MB, USS: %.2f MB, Shared: %.2f MB, Shared File: %.2f MB", + self.worker_id, + os.getpid(), + self.global_rank, + self.model_comm_group_id, + mem_info['rss'] / 1024 ** 2, + mem_info['pss'] / 1024 ** 2, + mem_info['uss'] / 1024 ** 2, + mem_info['shared'] / 1024 ** 2, + mem_info['shared_file'] / 1024 ** 2 + ) + + def _get_mem_info(self, pid: int) -> dict[str, int]: + """Retrieve detailed memory information for the given process.""" + res = defaultdict(int) + + # Iterate through memory maps to gather memory details + for mmap in psutil.Process(pid).memory_maps(): + res['rss'] += mmap.rss + res['pss'] += mmap.pss + res['uss'] += mmap.private_clean + mmap.private_dirty + res['shared'] += mmap.shared_clean + mmap.shared_dirty + + # If the path points to a file, classify it as shared file memory + if mmap.path.startswith('/'): + res['shared_file'] += mmap.shared_clean + mmap.shared_dirty + + return res def __repr__(self) -> str: return f""" From 5d171c769ebdd6081ae82cdc227253677173a621 Mon Sep 17 00:00:00 2001 From: japols Date: Mon, 7 Oct 2024 08:30:38 +0000 Subject: [PATCH 03/13] feat: add reader_group to define per-model_comm_group read behaviour via dataloader.read_frequency --- .../config/dataloader/native_grid.yaml | 2 + src/anemoi/training/data/datamodule.py | 10 +++- src/anemoi/training/data/dataset.py | 51 ++----------------- src/anemoi/training/distributed/strategy.py | 38 +++++++++++++- src/anemoi/training/train/forecaster.py | 36 +++++++++++-- src/anemoi/training/train/train.py | 1 + 6 files changed, 87 insertions(+), 51 deletions(-) diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 35ed7e35..5535af3a 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -1,5 +1,7 @@ prefetch_factor: 2 +read_frequency: 1 + num_workers: training: 8 validation: 8 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 4e1e4d1b..f20c6a9f 100644 --- a/src/anemoi/training/data/datamodule.py +++ 
b/src/anemoi/training/data/datamodule.py @@ -74,11 +74,18 @@ def __init__(self, config: DictConfig) -> None: * self.config.hardware.num_nodes // self.config.hardware.num_gpus_per_model ) # number of model communication groups + + self.reader_group_id = self.model_comm_group_rank // self.config.dataloader.read_frequency + self.reader_group_rank = self.model_comm_group_rank % self.config.dataloader.read_frequency + LOGGER.debug( - "Rank %d model communication group number %d, with local model communication group rank %d", + "Rank %d model communication group number %d, with local model communication group rank %d, " + "reader group number %d, with local reader group rank %d", self.global_rank, self.model_comm_group_id, self.model_comm_group_rank, + self.reader_group_id, + self.reader_group_rank, ) # Set the maximum rollout to be expected @@ -168,6 +175,7 @@ def _get_dataset( model_comm_group_rank=self.model_comm_group_rank, model_comm_group_id=self.model_comm_group_id, model_comm_num_groups=self.model_comm_num_groups, + reader_group_rank=self.reader_group_rank, shuffle=shuffle, label=label, ) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 3ed31654..1453882c 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -38,6 +38,7 @@ def __init__( model_comm_group_rank: int = 0, model_comm_group_id: int = 0, model_comm_num_groups: int = 1, + reader_group_rank: int = 0, shuffle: bool = True, label: str = "generic", ) -> None: @@ -82,6 +83,9 @@ def __init__( self.model_comm_group_id = model_comm_group_id self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) + # Reader group info + self.reader_group_rank = reader_group_rank + # additional state vars (lazy init) self.n_samples_per_worker = 0 self.chunk_index_range: np.ndarray | None = None @@ -194,15 +198,11 @@ def __iter__(self) -> torch.Tensor: Currently it receives data with an ensemble dimension, which is discarded for now. (Until the code is "ensemble native".) 
""" - if self.model_comm_group_rank != 0: + if self.reader_group_rank != 0: # yield dummy data only with shape information for non-root ranks shape = (self.rollout + self.multi_step, self.data.shape[2], self.data.shape[3], self.data.shape[1]) - LOGGER.debug(f"Rank {self.model_comm_group_rank} using dummy shape {shape}") - for _ in self.chunk_index_range: yield torch.tensor(shape, dtype=torch.long) - - self._log_memory_usage() return if self.shuffle: @@ -237,47 +237,6 @@ def __iter__(self) -> torch.Tensor: self.ensemble_dim = 1 yield torch.from_numpy(x) - LOGGER.debug(f"Worker {self.worker_id} yielded data with shape {x.shape}") - - self._log_memory_usage() - - def _log_memory_usage(self) -> None: - """Log detailed memory usage, including RSS, PSS, USS, and shared memory.""" - LOGGER.debug(f"Worker {self.worker_id} logging memory usage.") - process = psutil.Process(os.getpid()) - mem_info = self._get_mem_info(process.pid) - - # Log the detailed memory usage - LOGGER.debug( - "Worker %d (pid %d, global_rank %d, model comm group %d) memory usage (in MB): " - "RSS: %.2f MB, PSS: %.2f MB, USS: %.2f MB, Shared: %.2f MB, Shared File: %.2f MB", - self.worker_id, - os.getpid(), - self.global_rank, - self.model_comm_group_id, - mem_info['rss'] / 1024 ** 2, - mem_info['pss'] / 1024 ** 2, - mem_info['uss'] / 1024 ** 2, - mem_info['shared'] / 1024 ** 2, - mem_info['shared_file'] / 1024 ** 2 - ) - - def _get_mem_info(self, pid: int) -> dict[str, int]: - """Retrieve detailed memory information for the given process.""" - res = defaultdict(int) - - # Iterate through memory maps to gather memory details - for mmap in psutil.Process(pid).memory_maps(): - res['rss'] += mmap.rss - res['pss'] += mmap.pss - res['uss'] += mmap.private_clean + mmap.private_dirty - res['shared'] += mmap.shared_clean + mmap.shared_dirty - - # If the path points to a file, classify it as shared file memory - if mmap.path.startswith('/'): - res['shared_file'] += mmap.shared_clean + mmap.shared_dirty - - return res def __repr__(self) -> str: return f""" diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index c15828ca..0eb3fab5 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -24,19 +24,26 @@ class DDPGroupStrategy(DDPStrategy): """Distributed Data Parallel strategy with group communication.""" - def __init__(self, num_gpus_per_model: int, **kwargs: dict) -> None: + def __init__( + self, + num_gpus_per_model: int, + read_frequency: int, + **kwargs: dict) -> None: """Initialize the distributed strategy. Parameters ---------- num_gpus_per_model : int Number of GPUs per model to shard over. + read_frequency : int + Frequency of dataloader readers per model group. **kwargs : dict Additional keyword arguments. """ super().__init__(**kwargs) self.model_comm_group_size = num_gpus_per_model + self.read_frequency = read_frequency def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy" @@ -71,6 +78,35 @@ def setup(self, trainer: pl.Trainer) -> None: str(model_comm_group_ranks[model_comm_group_id]), ) + # set up reader groups by further splitting model_comm_group_ranks with read_frequency + assert self.model_comm_group_size % self.read_frequency == 0, ( + f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by the read frequency " + f"({self.read_frequency})." 
+ ) + + reader_group_ranks = np.array([ + np.split(group_ranks, int(self.model_comm_group_size / self.read_frequency)) + for group_ranks in model_comm_group_ranks + ]) # Shape: (num_model_comm_groups, model_comm_grp_size/read_freq, read_freq) + reader_groups = [ + [torch.distributed.new_group(x) for x in group_ranks] + for group_ranks in reader_group_ranks + ] + reader_group_id = model_comm_group_rank // self.read_frequency + reader_group_rank = model_comm_group_rank % self.read_frequency + # get all reader groups of the current model group + model_reader_groups = reader_groups[model_comm_group_id] + self.model.set_reader_groups(model_reader_groups) + + LOGGER.debug( + "Rank %d reader group is %s, model_comm_group number %d, local reader group number %d, local reader group rank %d", + self.global_rank, + str(reader_group_ranks[model_comm_group_id, reader_group_id]), + model_comm_group_id, + reader_group_id, + reader_group_rank, + ) + # register hooks for correct gradient reduction self.register_parameter_hooks() diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 51ba2d7c..785da4ce 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -110,6 +110,7 @@ def __init__( self.use_zero_optimizer = config.training.zero_optimizer self.model_comm_group = None + self.reader_groups = None LOGGER.debug("Rollout window length: %d", self.rollout) LOGGER.debug("Rollout increase every : %d epochs", self.rollout_epoch_increment) @@ -124,6 +125,21 @@ def __init__( config.hardware.num_gpus_per_node * config.hardware.num_nodes / config.hardware.num_gpus_per_model, ) + self.reader_group_size = config.dataloader.read_frequency + self.reader_group_id = self.model_comm_group_rank // self.reader_group_size + self.reader_group_rank = self.model_comm_group_rank % self.reader_group_size + # root rank (global rank) of the reader group + self.reader_group_root = (int(os.environ.get("SLURM_PROCID", "0")) // self.reader_group_size) * self.reader_group_size + + LOGGER.debug( + f"GraphForecaster: " + f"Rank {os.environ.get('SLURM_PROCID', '0')} model_comm_group_id: {self.model_comm_group_id}" + f" model_comm_group_rank: {self.model_comm_group_rank}" + f" reader_group_id: {self.reader_group_id}" + f" reader_group_rank: {self.reader_group_rank}" + f" reader_group_root: {self.reader_group_root}", + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x, self.model_comm_group) @@ -187,6 +203,10 @@ def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: LOGGER.debug("set_model_comm_group: %s", model_comm_group) self.model_comm_group = model_comm_group + def set_reader_groups(self, reader_groups: list[ProcessGroup]) -> None: + LOGGER.debug("set_reader_groups: %s", reader_groups) + self.reader_groups = reader_groups + def advance_input( self, x: torch.Tensor, @@ -220,16 +240,26 @@ def _step( ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx + # TODO: change to reader group # preprocess batch and broadcast from gpu0 to model comm group - if self.model_comm_group_rank == 0: # note that this defaults to 0 if model_comm_group is None + if self.reader_group_rank == 0: # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) else: shape = (batch.shape[0],) + tuple(batch[0].tolist()) batch = torch.zeros(shape, device=self.device) - if self.model_comm_group is not None: - 
torch.distributed.broadcast(batch, src=0, group=self.model_comm_group) + if self.reader_groups is not None and self.reader_group_size > 1: + if self.reader_group_rank == 0: + LOGGER.debug(f"Rank {int(os.environ.get('SLURM_PROCID', '0'))} broadcasting batch") + else: + LOGGER.debug(f"Rank {int(os.environ.get('SLURM_PROCID', '0'))} waiting for broadcast from rank {self.reader_group_root}") + + torch.distributed.broadcast(batch, src=self.reader_group_root, group=self.reader_groups[self.reader_group_id]) + + # Synchronize after the broadcast to ensure that model_comm_group and reader_group don't overlap + # see https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group WARNING + torch.distributed.barrier(group=self.reader_groups[self.reader_group_id]) loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) metrics = {} diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index f48b9467..6821668f 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -308,6 +308,7 @@ def strategy(self) -> DDPGroupStrategy: """Training strategy.""" return DDPGroupStrategy( self.config.hardware.num_gpus_per_model, + self.config.dataloader.read_frequency, static_graph=not self.config.training.accum_grad_batches > 1, ) From ee9459391080751d6389779b5da9ca8a0a47cdfd Mon Sep 17 00:00:00 2001 From: japols Date: Wed, 9 Oct 2024 08:51:30 +0000 Subject: [PATCH 04/13] docs: cleanup, add comments --- CHANGELOG.md | 1 + .../training/config/dataloader/native_grid.yaml | 7 +++++++ src/anemoi/training/data/datamodule.py | 1 + src/anemoi/training/data/dataset.py | 5 +---- src/anemoi/training/distributed/strategy.py | 3 ++- src/anemoi/training/train/forecaster.py | 11 +++-------- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e1bbd55c..db81e451 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Keep it human-readable, your future self will thank you! ### Added - Codeowners file (#56) - Changelog merge strategy (#56) +- Feature: Add reader groups to reduce CPU memory usage (#75) #### Miscellaneous diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 5535af3a..4fe5ba20 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -1,5 +1,12 @@ prefetch_factor: 2 +# ============ +# read_frequency: +# Only ever read_frequency-th GPU of each model commm group reads data +# to reduce CPU memory usage. +# The number of GPUs per model must be divisible by read_frequency. 
+# Default: 1 (all GPUs read data) +# ============ read_frequency: 1 num_workers: diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index b4fd5055..8df94361 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -75,6 +75,7 @@ def __init__(self, config: DictConfig) -> None: // self.config.hardware.num_gpus_per_model ) # number of model communication groups + # get reader_group_id (inside model_comm_group) and reader_group_rank (inside reader_group) self.reader_group_id = self.model_comm_group_rank // self.config.dataloader.read_frequency self.reader_group_rank = self.model_comm_group_rank % self.config.dataloader.read_frequency diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index d58e8361..ca3bcda3 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -12,9 +12,6 @@ from functools import cached_property from typing import Callable -import psutil -from collections import defaultdict - import numpy as np import torch from einops import rearrange @@ -211,7 +208,7 @@ def __iter__(self) -> torch.Tensor: now. (Until the code is "ensemble native".) """ if self.reader_group_rank != 0: - # yield dummy data only with shape information for non-root ranks + # yield dummy data only with shape information for non-root ranks (shape used for broadcast) shape = (self.rollout + self.multi_step, self.data.shape[2], self.data.shape[3], self.data.shape[1]) for _ in self.chunk_index_range: yield torch.tensor(shape, dtype=torch.long) diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index 0eb3fab5..b27c2e7e 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -78,7 +78,8 @@ def setup(self, trainer: pl.Trainer) -> None: str(model_comm_group_ranks[model_comm_group_id]), ) - # set up reader groups by further splitting model_comm_group_ranks with read_frequency + # set up reader groups by further splitting model_comm_group_ranks with read_frequency: + assert self.model_comm_group_size % self.read_frequency == 0, ( f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by the read frequency " f"({self.read_frequency})." 
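As an aside, a minimal standalone sketch of the rank arithmetic these changes rely on: read_frequency splits each model communication group into reader groups, only the root rank of each reader group reads data, and the preprocessed batch is then broadcast to the other ranks of that reader group. The world size and group sizes below are assumed for illustration and are not part of the patch.

import numpy as np

# Assumed setup: 4 GPUs in total, one model sharded over all 4, read_frequency = 2,
# giving two reader groups of two ranks each.
world_size, num_gpus_per_model, read_frequency = 4, 4, 2

global_ranks = np.arange(world_size)
model_comm_group_ranks = np.split(global_ranks, world_size // num_gpus_per_model)
# -> [array([0, 1, 2, 3])]
reader_group_ranks = [
    np.split(group, num_gpus_per_model // read_frequency) for group in model_comm_group_ranks
]
# -> [[array([0, 1]), array([2, 3])]]

for global_rank in global_ranks:
    model_comm_group_rank = global_rank % num_gpus_per_model
    reader_group_id = model_comm_group_rank // read_frequency             # 0, 0, 1, 1
    reader_group_rank = model_comm_group_rank % read_frequency            # 0, 1, 0, 1
    reader_group_root = (global_rank // read_frequency) * read_frequency  # 0, 0, 2, 2
    # ranks 0 and 2 read and preprocess; ranks 1 and 3 receive the batch via broadcast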
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 785da4ce..0c7f3f97 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -128,7 +128,7 @@ def __init__( self.reader_group_size = config.dataloader.read_frequency self.reader_group_id = self.model_comm_group_rank // self.reader_group_size self.reader_group_rank = self.model_comm_group_rank % self.reader_group_size - # root rank (global rank) of the reader group + # global rank of the root of the current reader group (required for broadcasting): self.reader_group_root = (int(os.environ.get("SLURM_PROCID", "0")) // self.reader_group_size) * self.reader_group_size LOGGER.debug( @@ -240,21 +240,16 @@ def _step( ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx - # TODO: change to reader group - # preprocess batch and broadcast from gpu0 to model comm group + # preprocess batch and broadcast from reader_group rank 0 to reader_group if self.reader_group_rank == 0: # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) else: + # init batch tensor with correct shape on non-root ranks shape = (batch.shape[0],) + tuple(batch[0].tolist()) batch = torch.zeros(shape, device=self.device) if self.reader_groups is not None and self.reader_group_size > 1: - if self.reader_group_rank == 0: - LOGGER.debug(f"Rank {int(os.environ.get('SLURM_PROCID', '0'))} broadcasting batch") - else: - LOGGER.debug(f"Rank {int(os.environ.get('SLURM_PROCID', '0'))} waiting for broadcast from rank {self.reader_group_root}") - torch.distributed.broadcast(batch, src=self.reader_group_root, group=self.reader_groups[self.reader_group_id]) # Synchronize after the broadcast to ensure that model_comm_group and reader_group don't overlap From 3c6b5c9daecd2847b15c3086e9b98e726270f835 Mon Sep 17 00:00:00 2001 From: japols Date: Wed, 9 Oct 2024 15:37:43 +0000 Subject: [PATCH 05/13] refactor: Pass model/reader group information from DDPGroupStrategy instead of SLURM_PROCID --- CHANGELOG.md | 2 +- src/anemoi/training/data/datamodule.py | 38 ------ src/anemoi/training/data/dataset.py | 70 +++++++---- src/anemoi/training/distributed/strategy.py | 132 ++++++++++++++------ src/anemoi/training/train/forecaster.py | 74 +++++------ src/anemoi/training/train/train.py | 5 +- 6 files changed, 183 insertions(+), 138 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index db81e451..fd6093a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ Keep it human-readable, your future self will thank you! ### Added - Codeowners file (#56) - Changelog merge strategy (#56) -- Feature: Add reader groups to reduce CPU memory usage (#75) +- Feature: Add reader groups to reduce CPU memory usage [#76](https://github.com/ecmwf/anemoi-training/pull/76) #### Miscellaneous diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 8df94361..907a9b8f 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -6,7 +6,6 @@ # nor does it submit to any jurisdiction. 
import logging -import os from functools import cached_property from typing import Callable @@ -56,39 +55,6 @@ def __init__(self, config: DictConfig) -> None: timestep, ) - self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank - self.model_comm_group_id = ( - self.global_rank // self.config.hardware.num_gpus_per_model - ) # id of the model communication group the rank is participating in - self.model_comm_group_rank = ( - self.global_rank % self.config.hardware.num_gpus_per_model - ) # rank within one model communication group - total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes - assert ( - total_gpus - ) % self.config.hardware.num_gpus_per_model == 0, ( - f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}" - ) - self.model_comm_num_groups = ( - self.config.hardware.num_gpus_per_node - * self.config.hardware.num_nodes - // self.config.hardware.num_gpus_per_model - ) # number of model communication groups - - # get reader_group_id (inside model_comm_group) and reader_group_rank (inside reader_group) - self.reader_group_id = self.model_comm_group_rank // self.config.dataloader.read_frequency - self.reader_group_rank = self.model_comm_group_rank % self.config.dataloader.read_frequency - - LOGGER.debug( - "Rank %d model communication group number %d, with local model communication group rank %d, " - "reader group number %d, with local reader group rank %d", - self.global_rank, - self.model_comm_group_id, - self.model_comm_group_rank, - self.reader_group_id, - self.reader_group_rank, - ) - # Set the maximum rollout to be expected self.rollout = ( self.config.training.rollout.max @@ -175,10 +141,6 @@ def _get_dataset( rollout=r, multistep=self.config.training.multistep_input, timeincrement=self.timeincrement, - model_comm_group_rank=self.model_comm_group_rank, - model_comm_group_id=self.model_comm_group_id, - model_comm_num_groups=self.model_comm_num_groups, - reader_group_rank=self.reader_group_rank, shuffle=shuffle, label=label, ) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index ca3bcda3..e16bc6bf 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -33,10 +33,6 @@ def __init__( rollout: int = 1, multistep: int = 1, timeincrement: int = 1, - model_comm_group_rank: int = 0, - model_comm_group_id: int = 0, - model_comm_num_groups: int = 1, - reader_group_rank: int = 0, shuffle: bool = True, label: str = "generic", ) -> None: @@ -52,12 +48,6 @@ def __init__( time increment between samples, by default 1 multistep : int, optional collate (t-1, ... 
t - multistep) into the input state vector, by default 1 - model_comm_group_rank : int, optional - process rank in the torch.distributed group (important when running on multiple GPUs), by default 0 - model_comm_group_id: int, optional - device group ID, default 0 - model_comm_num_groups : int, optional - total number of device groups, by default 1 shuffle : bool, optional Shuffle batches, by default True label : str, optional @@ -75,14 +65,13 @@ def __init__( self.n_samples_per_epoch_total: int = 0 self.n_samples_per_epoch_per_worker: int = 0 - # DDP-relevant info - self.model_comm_group_rank = model_comm_group_rank - self.model_comm_num_groups = model_comm_num_groups - self.model_comm_group_id = model_comm_group_id - self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) + # lazy init model and reader group info, will be set by the DDPGroupStrategy: + self.model_comm_group_rank = 0 + self.model_comm_num_groups = 1 + self.model_comm_group_id = 0 + self.global_rank = 0 - # Reader group info - self.reader_group_rank = reader_group_rank + self.reader_group_rank = 0 # additional state vars (lazy init) self.n_samples_per_worker = 0 @@ -129,6 +118,45 @@ def valid_date_indices(self) -> np.ndarray: """ return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) + def set_comm_group_info( + self, + global_rank: int, + model_comm_group_id: int, + model_comm_group_rank: int, + model_comm_num_groups: int, + reader_group_rank: int, + ) -> None: + """Set model and reader communication group information (called by DDPGroupStrategy). + + Parameters + ---------- + global_rank : int + Global rank + model_comm_group_id : int + Model communication group ID + model_comm_group_rank : int + Model communication group rank + model_comm_num_groups : int + Number of model communication groups + reader_group_rank : int + Reader group rank + """ + self.global_rank = global_rank + self.model_comm_group_id = model_comm_group_id + self.model_comm_group_rank = model_comm_group_rank + self.model_comm_num_groups = model_comm_num_groups + self.reader_group_rank = reader_group_rank + + LOGGER.debug( + "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " + "model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d", + global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + ) + def per_worker_init(self, n_workers: int, worker_id: int) -> None: """Called by worker_init_func on each copy of dataset. @@ -207,13 +235,13 @@ def __iter__(self) -> torch.Tensor: Currently it receives data with an ensemble dimension, which is discarded for now. (Until the code is "ensemble native".) """ - if self.reader_group_rank != 0: + if self.reader_group_rank != 0: # yield dummy data only with shape information for non-root ranks (shape used for broadcast) shape = (self.rollout + self.multi_step, self.data.shape[2], self.data.shape[3], self.data.shape[1]) - for _ in self.chunk_index_range: + for _ in self.chunk_index_range: yield torch.tensor(shape, dtype=torch.long) - return - + return + if self.shuffle: shuffled_chunk_indices = self.rng.choice( self.chunk_index_range, diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index b27c2e7e..64cb966b 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -6,7 +6,6 @@ # nor does it submit to any jurisdiction. 
import logging -import os import numpy as np import pytorch_lightning as pl @@ -24,11 +23,7 @@ class DDPGroupStrategy(DDPStrategy): """Distributed Data Parallel strategy with group communication.""" - def __init__( - self, - num_gpus_per_model: int, - read_frequency: int, - **kwargs: dict) -> None: + def __init__(self, num_gpus_per_model: int, read_frequency: int, **kwargs: dict) -> None: """Initialize the distributed strategy. Parameters @@ -64,18 +59,15 @@ def setup(self, trainer: pl.Trainer) -> None: torch.distributed.new_group(x) for x in model_comm_group_ranks ] # every rank has to create all of these - model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group( + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( self.model_comm_group_size, ) model_comm_group = model_comm_groups[model_comm_group_id] - self.model.set_model_comm_group(model_comm_group) - LOGGER.debug( - "Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s", - self.global_rank, - str(model_comm_group_nr), + self.model.set_model_comm_group( + model_comm_group, model_comm_group_id, model_comm_group_rank, - str(model_comm_group_ranks[model_comm_group_id]), + model_comm_num_groups, ) # set up reader groups by further splitting model_comm_group_ranks with read_frequency: @@ -85,27 +77,38 @@ def setup(self, trainer: pl.Trainer) -> None: f"({self.read_frequency})." ) - reader_group_ranks = np.array([ - np.split(group_ranks, int(self.model_comm_group_size / self.read_frequency)) - for group_ranks in model_comm_group_ranks - ]) # Shape: (num_model_comm_groups, model_comm_grp_size/read_freq, read_freq) - reader_groups = [ - [torch.distributed.new_group(x) for x in group_ranks] - for group_ranks in reader_group_ranks - ] - reader_group_id = model_comm_group_rank // self.read_frequency - reader_group_rank = model_comm_group_rank % self.read_frequency + reader_group_ranks = np.array( + [ + np.split(group_ranks, int(self.model_comm_group_size / self.read_frequency)) + for group_ranks in model_comm_group_ranks + ], + ) # Shape: (num_model_comm_groups, model_comm_grp_size/read_freq, read_freq) + reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks] + reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group( + model_comm_group_rank, + self.read_frequency, + ) # get all reader groups of the current model group model_reader_groups = reader_groups[model_comm_group_id] - self.model.set_reader_groups(model_reader_groups) + self.model.set_reader_groups( + model_reader_groups, + reader_group_id, + reader_group_rank, + reader_group_size, + reader_group_root, + ) LOGGER.debug( - "Rank %d reader group is %s, model_comm_group number %d, local reader group number %d, local reader group rank %d", + "Rank %d model_comm_group_id: %d model_comm_group: %s model_comm_group_rank: %d " + "reader_group_id: %d reader_group: %s reader_group_rank: %d reader_group_root (global): %d", self.global_rank, - str(reader_group_ranks[model_comm_group_id, reader_group_id]), model_comm_group_id, + str(model_comm_group_ranks[model_comm_group_id]), + model_comm_group_rank, reader_group_id, - reader_group_rank, + reader_group_ranks[model_comm_group_id, reader_group_id], + reader_group_rank, + reader_group_root, ) # register hooks for correct gradient reduction @@ -143,7 +146,7 @@ def setup(self, trainer: pl.Trainer) -> None: # seed ranks 
self.seed_rnd(model_comm_group_id) - def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndarray, int]: + def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, int]: """Determine tasks that work together and from a model group. Parameters @@ -153,19 +156,68 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndar Returns ------- - tuple[int, np.ndarray, int] - Model_comm_group id, Model_comm_group Nr, Model_comm_group rank + tuple[int, int, int] + Model_comm_group id, Model_comm_group rank, Number of model_comm_groups + """ + model_comm_group_id = self.global_rank // num_gpus_per_model + model_comm_group_rank = self.global_rank % num_gpus_per_model + model_comm_num_groups = self.world_size // num_gpus_per_model + + return model_comm_group_id, model_comm_group_rank, model_comm_num_groups + + def get_my_reader_group(self, model_comm_group_rank: int, read_frequency: int) -> tuple[int, int, int]: + """Determine tasks that work together and from a reader group. + + Parameters + ---------- + model_comm_group_rank : int + Rank within the model communication group. + read_frequency : int + Frequency of dataloader readers per model group. + + Returns + ------- + tuple[int, int, int] + Reader_group id, Reader_group rank, Reader_group root (global rank) """ - model_comm_groups = np.arange(0, self.world_size, dtype=np.int32) - model_comm_groups = np.split(model_comm_groups, self.world_size / num_gpus_per_model) + reader_group_id = model_comm_group_rank // read_frequency + reader_group_rank = model_comm_group_rank % read_frequency + reader_group_size = read_frequency + reader_group_root = (self.global_rank // read_frequency) * read_frequency - model_comm_group_id = None - for i, model_comm_group in enumerate(model_comm_groups): - if self.global_rank in model_comm_group: - model_comm_group_id = i - model_comm_group_nr = model_comm_group - model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0] - return model_comm_group_id, model_comm_group_nr, model_comm_group_rank + return reader_group_id, reader_group_rank, reader_group_size, reader_group_root + + def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader: + """Pass communication group information to the dataloader for distributed training. + + Parameters + ---------- + dataloader : torch.utils.data.DataLoader + Dataloader to process. + + Returns + ------- + torch.utils.data.DataLoader + Processed dataloader. 
+ + """ + dataloader = super().process_dataloader(dataloader) + + # pass model and reader group information to the dataloaders dataset + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( + self.model_comm_group_size, + ) + _, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_frequency) + + dataloader.dataset.set_comm_group_info( + self.global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + ) + + return dataloader def seed_rnd(self, model_comm_group_id: int) -> None: """Seed the random number generators for the rank.""" @@ -179,7 +231,7 @@ def seed_rnd(self, model_comm_group_id: int) -> None: "Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, " "running with random seed: %d, sanity rnd: %s" ), - int(os.environ.get("SLURM_PROCID", "0")), + self.global_rank, model_comm_group_id, base_seed, initial_seed, diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 0c7f3f97..8ca71e46 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -8,8 +8,6 @@ # import logging -import math -import os from collections import defaultdict from collections.abc import Mapping @@ -119,26 +117,14 @@ def __init__( self.enable_plot = config.diagnostics.plot.enabled - self.model_comm_group_id = int(os.environ.get("SLURM_PROCID", "0")) // config.hardware.num_gpus_per_model - self.model_comm_group_rank = int(os.environ.get("SLURM_PROCID", "0")) % config.hardware.num_gpus_per_model - self.model_comm_num_groups = math.ceil( - config.hardware.num_gpus_per_node * config.hardware.num_nodes / config.hardware.num_gpus_per_model, - ) + # lazy init model and reader group info, will be set by the DDPGroupStrategy: + self.model_comm_group_id = 0 + self.model_comm_group_rank = 0 + self.model_comm_num_groups = 1 - self.reader_group_size = config.dataloader.read_frequency - self.reader_group_id = self.model_comm_group_rank // self.reader_group_size - self.reader_group_rank = self.model_comm_group_rank % self.reader_group_size - # global rank of the root of the current reader group (required for broadcasting): - self.reader_group_root = (int(os.environ.get("SLURM_PROCID", "0")) // self.reader_group_size) * self.reader_group_size - - LOGGER.debug( - f"GraphForecaster: " - f"Rank {os.environ.get('SLURM_PROCID', '0')} model_comm_group_id: {self.model_comm_group_id}" - f" model_comm_group_rank: {self.model_comm_group_rank}" - f" reader_group_id: {self.reader_group_id}" - f" reader_group_rank: {self.reader_group_rank}" - f" reader_group_root: {self.reader_group_root}", - ) + self.reader_group_id = 0 + self.reader_group_rank = 0 + self.reader_group_root = 0 def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x, self.model_comm_group) @@ -199,13 +185,31 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t metric_ranges_validation[key] = [idx] return metric_ranges, metric_ranges_validation, loss_scaling - def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: - LOGGER.debug("set_model_comm_group: %s", model_comm_group) + def set_model_comm_group( + self, + model_comm_group: ProcessGroup, + model_comm_group_id: int, + model_comm_group_rank: int, + model_comm_num_groups: int, + ) -> None: self.model_comm_group = model_comm_group + self.model_comm_group_id = model_comm_group_id + self.model_comm_group_rank = model_comm_group_rank + 
self.model_comm_num_groups = model_comm_num_groups - def set_reader_groups(self, reader_groups: list[ProcessGroup]) -> None: - LOGGER.debug("set_reader_groups: %s", reader_groups) + def set_reader_groups( + self, + reader_groups: list[ProcessGroup], + reader_group_id: int, + reader_group_rank: int, + reader_group_size: int, + reader_group_root: int, + ) -> None: self.reader_groups = reader_groups + self.reader_group_id = reader_group_id + self.reader_group_rank = reader_group_rank + self.reader_group_size = reader_group_size + self.reader_group_root = reader_group_root def advance_input( self, @@ -240,21 +244,21 @@ def _step( ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx - # preprocess batch and broadcast from reader_group rank 0 to reader_group - if self.reader_group_rank == 0: + # preprocess batch and broadcast from reader_group rank 0 to reader_group + if self.reader_group_rank == 0: # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) - else: + else: # init batch tensor with correct shape on non-root ranks - shape = (batch.shape[0],) + tuple(batch[0].tolist()) - batch = torch.zeros(shape, device=self.device) + shape = (batch.shape[0], *tuple(batch[0].tolist())) + batch = torch.empty(shape, device=self.device) if self.reader_groups is not None and self.reader_group_size > 1: - torch.distributed.broadcast(batch, src=self.reader_group_root, group=self.reader_groups[self.reader_group_id]) - - # Synchronize after the broadcast to ensure that model_comm_group and reader_group don't overlap - # see https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group WARNING - torch.distributed.barrier(group=self.reader_groups[self.reader_group_id]) + torch.distributed.broadcast( + batch, + src=self.reader_group_root, + group=self.reader_groups[self.reader_group_id], + ) loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) metrics = {} diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 6821668f..9083442e 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -11,7 +11,6 @@ import datetime import logging -import os from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING @@ -101,7 +100,7 @@ def initial_seed(self) -> int: (torch.rand(1), np_rng.random()) LOGGER.debug( "Initial seed: Rank %d, initial seed %d, running with random seed: %d", - int(os.environ.get("SLURM_PROCID", "0")), + self.strategy.global_rank, initial_seed, rnd_seed, ) @@ -308,7 +307,7 @@ def strategy(self) -> DDPGroupStrategy: """Training strategy.""" return DDPGroupStrategy( self.config.hardware.num_gpus_per_model, - self.config.dataloader.read_frequency, + self.config.dataloader.get("read_frequency", 1), static_graph=not self.config.training.accum_grad_batches > 1, ) From 57a13c54d41dd5d89fbdbbb4f27ee6ba281a8493 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:50:47 +0000 Subject: [PATCH 06/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/training/config/dataloader/native_grid.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 4fe5ba20..8ae51553 100644 --- 
a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -3,7 +3,7 @@ prefetch_factor: 2 # ============ # read_frequency: # Only ever read_frequency-th GPU of each model commm group reads data -# to reduce CPU memory usage. +# to reduce CPU memory usage. # The number of GPUs per model must be divisible by read_frequency. # Default: 1 (all GPUs read data) # ============ From 9a22225afb1c9e47e33c11d9714966721e883244 Mon Sep 17 00:00:00 2001 From: japols Date: Thu, 24 Oct 2024 19:09:03 +0000 Subject: [PATCH 07/13] feat: Add support for sharded reading in dataloader --- .../config/dataloader/native_grid.yaml | 11 +++- src/anemoi/training/data/datamodule.py | 1 + src/anemoi/training/data/dataset.py | 24 ++++++- src/anemoi/training/distributed/strategy.py | 2 + src/anemoi/training/train/forecaster.py | 65 +++++++++++++++---- 5 files changed, 87 insertions(+), 16 deletions(-) diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 46deed08..9086daab 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -6,10 +6,19 @@ pin_memory: True # Only ever read_frequency-th GPU of each model commm group reads data # to reduce CPU memory usage. # The number of GPUs per model must be divisible by read_frequency. -# Default: 1 (all GPUs read data) +# Default: 1 (all GPUs read data), only if read_shards is False # ============ read_frequency: 1 +# ============ +# read_shards: +# Every GPU only reads 1/num_gpus_per_model of its data sharded along +# the grid dimension which is then put back together via all-gather. +# This can reduce CPU memory usage as well as increase dataloader throughput. +# Default: True, only works if read_frequency is 1 +# ============ +read_shards: True + num_workers: training: 8 validation: 8 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 062e0073..2bcf05af 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -149,6 +149,7 @@ def _get_dataset( timeincrement=self.timeincrement, shuffle=shuffle, label=label, + read_shards=self.config.dataloader.read_shards, ) self._check_resolution(data.resolution) return data diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index e9773f6e..dcbb2816 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -38,6 +38,7 @@ def __init__( timeincrement: int = 1, shuffle: bool = True, label: str = "generic", + read_shards: bool = False, ) -> None: """Initialize (part of) the dataset state. @@ -72,8 +73,11 @@ def __init__( self.model_comm_group_rank = 0 self.model_comm_num_groups = 1 self.model_comm_group_id = 0 + self.model_comm_group_size = 1 self.global_rank = 0 + self.read_shards = read_shards + self.reader_group_rank = 0 # additional state vars (lazy init) @@ -86,6 +90,7 @@ def __init__( assert self.multi_step > 0, "Multistep value must be greater than zero." 
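As an illustration of what read_shards changes at the dataset level: each rank of a model communication group now reads only its slice of the grid dimension, and the full field is reassembled later with an all-gather. The sketch below uses small assumed sizes; the real shard indices are computed in set_comm_group_info and used in __iter__ further down in this diff.

import numpy as np

# Assumed sizes for illustration only.
grid_size = 16                 # total number of grid points
model_comm_group_size = 4      # GPUs sharing one model instance
model_comm_group_rank = 2      # this rank's position in the group

grid_shard_size = grid_size // model_comm_group_size
grid_start = model_comm_group_rank * grid_shard_size          # 8
grid_end = (
    grid_size
    if model_comm_group_rank == model_comm_group_size - 1
    else (model_comm_group_rank + 1) * grid_shard_size
)                                                              # 12

data = np.zeros((8, 5, 1, grid_size))                          # (dates, variables, ensemble, gridpoints)
x = data[0:6:1, :, :, grid_start:grid_end]                     # sharded read: (6, 5, 1, 4)
full = data[0:6:1, :, :, :]                                    # full read:    (6, 5, 1, 16)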
self.ensemble_dim: int = 2 self.ensemble_size = self.data.shape[self.ensemble_dim] + self.grid_size = self.data.shape[-1] @cached_property def statistics(self) -> dict: @@ -124,6 +129,7 @@ def valid_date_indices(self) -> np.ndarray: def set_comm_group_info( self, global_rank: int, + model_comm_group_size: int, model_comm_group_id: int, model_comm_group_rank: int, model_comm_num_groups: int, @@ -135,6 +141,8 @@ def set_comm_group_info( ---------- global_rank : int Global rank + model_comm_group_size : int + Model communication group size model_comm_group_id : int Model communication group ID model_comm_group_rank : int @@ -145,11 +153,21 @@ def set_comm_group_info( Reader group rank """ self.global_rank = global_rank + self.model_comm_group_size = model_comm_group_size self.model_comm_group_id = model_comm_group_id self.model_comm_group_rank = model_comm_group_rank self.model_comm_num_groups = model_comm_num_groups self.reader_group_rank = reader_group_rank + if self.read_shards: + # get the grid shard size and start/end indices + grid_shard_size = self.grid_size // self.model_comm_group_size + self.grid_start = self.model_comm_group_rank * grid_shard_size + if self.model_comm_group_rank == self.model_comm_group_size - 1: + self.grid_end = self.grid_size + else: + self.grid_end = (self.model_comm_group_rank + 1) * grid_shard_size + LOGGER.debug( "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " "model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d", @@ -272,7 +290,11 @@ def __iter__(self) -> torch.Tensor: start = i - (self.multi_step - 1) * self.timeincrement end = i + (self.rollout + 1) * self.timeincrement - x = self.data[start : end : self.timeincrement] + if self.read_shards: + x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end] + else: + x = self.data[start : end : self.timeincrement, :, :, :] + x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") self.ensemble_dim = 1 diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index a1f5c40f..53ecf09e 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -71,6 +71,7 @@ def setup(self, trainer: pl.Trainer) -> None: model_comm_group_id, model_comm_group_rank, model_comm_num_groups, + self.model_comm_group_size, ) # set up reader groups by further splitting model_comm_group_ranks with read_frequency: @@ -214,6 +215,7 @@ def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.u dataloader.dataset.set_comm_group_info( self.global_rank, + self.model_comm_group_size, model_comm_group_id, model_comm_group_rank, model_comm_num_groups, diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 839fb802..1a77494e 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -135,6 +135,8 @@ def __init__( self.reader_group_rank = 0 self.reader_group_root = 0 + self.read_shards = config.dataloader.read_shards + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x, self.model_comm_group) @@ -200,11 +202,13 @@ def set_model_comm_group( model_comm_group_id: int, model_comm_group_rank: int, model_comm_num_groups: int, + model_comm_group_size: int, ) -> None: self.model_comm_group = model_comm_group self.model_comm_group_id = model_comm_group_id self.model_comm_group_rank = model_comm_group_rank 
self.model_comm_num_groups = model_comm_num_groups + self.model_comm_group_size = model_comm_group_size def set_reader_groups( self, @@ -220,6 +224,10 @@ def set_reader_groups( self.reader_group_size = reader_group_size self.reader_group_root = reader_group_root + assert not ( + self.reader_group_size > 1 and self.read_shards + ), "Reading shards is not supported with reader group size > 1" + def advance_input( self, x: torch.Tensor, @@ -255,21 +263,15 @@ def _step( ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx - # preprocess batch and broadcast from reader_group rank 0 to reader_group - if self.reader_group_rank == 0: - # for validation not normalized in-place because remappers cannot be applied in-place - batch = self.model.pre_processors(batch, in_place=not validation_mode) - else: - # init batch tensor with correct shape on non-root ranks - shape = (batch.shape[0], *tuple(batch[0].tolist())) - batch = torch.empty(shape, device=self.device) - if self.reader_groups is not None and self.reader_group_size > 1: - torch.distributed.broadcast( - batch, - src=self.reader_group_root, - group=self.reader_groups[self.reader_group_id], - ) + batch = self.broadcast_batch(batch) + + # all gather grid shards from model_comm_group + if self.model_comm_group is not None and self.read_shards: + batch = self.allgather_batch(batch) + + # for validation not normalized in-place because remappers cannot be applied in-place + batch = self.model.pre_processors(batch, in_place=not validation_mode) loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) metrics = {} @@ -307,6 +309,41 @@ def _step( loss *= 1.0 / self.rollout return loss, metrics, y_preds + def broadcast_batch(self, batch: torch.Tensor) -> torch.Tensor: + if self.reader_group_rank != 0: + # init batch tensor with correct shape on non-root ranks + shape = (batch.shape[0], *tuple(batch[0].tolist())) + batch = torch.empty(shape, device=self.device) + + torch.distributed.broadcast( + batch, + src=self.reader_group_root, + group=self.reader_groups[self.reader_group_id], + ) + + return batch + + def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: + # get shard shapes + grid_size = self.model.metadata["dataset"]["shape"][-1] + shard_shape = list(batch.shape) + # handle cases of last grid shard + shard_shape[-2] = grid_size // self.model_comm_group_size + last_grid_dim = grid_size - (shard_shape[-2] * (self.model_comm_group_size - 1)) + + tensor_list = [ + torch.empty(tuple(shard_shape), device=self.device) for _ in range(self.model_comm_group_size - 1) + ] + tensor_list.append(torch.empty((*tuple(shard_shape[:-2]), last_grid_dim, shard_shape[-1]), device=self.device)) + + torch.distributed.all_gather( + tensor_list, + batch, + group=self.model_comm_group, + ) + + return torch.cat(tensor_list, dim=-2) + def calculate_val_metrics( self, y_pred: torch.Tensor, From 6615c9798c1fc588f943c33e78b81d1b1a3add52 Mon Sep 17 00:00:00 2001 From: japols Date: Fri, 25 Oct 2024 12:33:40 +0000 Subject: [PATCH 08/13] refactor: merge read groups with sharded reading functionality in dataloader --- .../config/dataloader/native_grid.yaml | 21 +++----- src/anemoi/training/data/datamodule.py | 1 - src/anemoi/training/data/dataset.py | 34 +++++-------- src/anemoi/training/distributed/strategy.py | 41 ++++++++-------- src/anemoi/training/train/forecaster.py | 49 +++++-------------- src/anemoi/training/train/train.py | 2 +- 6 files changed, 52 insertions(+), 96 deletions(-) diff --git 
a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 9086daab..7262117e 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -2,22 +2,15 @@ prefetch_factor: 2 pin_memory: True # ============ -# read_frequency: -# Only ever read_frequency-th GPU of each model commm group reads data -# to reduce CPU memory usage. -# The number of GPUs per model must be divisible by read_frequency. -# Default: 1 (all GPUs read data), only if read_shards is False -# ============ -read_frequency: 1 - -# ============ -# read_shards: -# Every GPU only reads 1/num_gpus_per_model of its data sharded along -# the grid dimension which is then put back together via all-gather. +# read_group_size: +# Form subgroups of model comm groups that read data together. +# Each reader in the group only reads 1/read_group_size of the data +# which is then all-gathered between the group. # This can reduce CPU memory usage as well as increase dataloader throughput. -# Default: True, only works if read_frequency is 1 +# The number of GPUs per model must be divisible by read_group_size. +# Good values are num_gpus_per_model or num_gpus_per_node. # ============ -read_shards: True +read_group_size: 1 num_workers: training: 8 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 2bcf05af..062e0073 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -149,7 +149,6 @@ def _get_dataset( timeincrement=self.timeincrement, shuffle=shuffle, label=label, - read_shards=self.config.dataloader.read_shards, ) self._check_resolution(data.resolution) return data diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index dcbb2816..e6a10943 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -38,7 +38,6 @@ def __init__( timeincrement: int = 1, shuffle: bool = True, label: str = "generic", - read_shards: bool = False, ) -> None: """Initialize (part of) the dataset state. @@ -73,12 +72,10 @@ def __init__( self.model_comm_group_rank = 0 self.model_comm_num_groups = 1 self.model_comm_group_id = 0 - self.model_comm_group_size = 1 self.global_rank = 0 - self.read_shards = read_shards - self.reader_group_rank = 0 + self.reader_group_size = 1 # additional state vars (lazy init) self.n_samples_per_worker = 0 @@ -129,11 +126,11 @@ def valid_date_indices(self) -> np.ndarray: def set_comm_group_info( self, global_rank: int, - model_comm_group_size: int, model_comm_group_id: int, model_comm_group_rank: int, model_comm_num_groups: int, reader_group_rank: int, + reader_group_size: int, ) -> None: """Set model and reader communication group information (called by DDPGroupStrategy). 
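A small worked example of the shard boundaries implied by read_group_size, mirroring the grid_start/grid_end logic added to set_comm_group_info below: the grid is split evenly across the reader group and the last rank absorbs any remainder. The numbers are assumed for illustration.

# Assumed example: a grid of 10 points shared by a reader group of 4 ranks.
grid_size, reader_group_size = 10, 4
grid_shard_size = grid_size // reader_group_size                  # 2

for reader_group_rank in range(reader_group_size):
    grid_start = reader_group_rank * grid_shard_size              # 0, 2, 4, 6
    if reader_group_rank == reader_group_size - 1:
        grid_end = grid_size                                      # last shard runs to 10
    else:
        grid_end = (reader_group_rank + 1) * grid_shard_size      # 2, 4, 6
    print(reader_group_rank, grid_start, grid_end)                # shards [0,2) [2,4) [4,6) [6,10)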
@@ -141,8 +138,6 @@ def set_comm_group_info( ---------- global_rank : int Global rank - model_comm_group_size : int - Model communication group size model_comm_group_id : int Model communication group ID model_comm_group_rank : int @@ -151,22 +146,24 @@ def set_comm_group_info( Number of model communication groups reader_group_rank : int Reader group rank + reader_group_size : int + Reader group size """ self.global_rank = global_rank - self.model_comm_group_size = model_comm_group_size self.model_comm_group_id = model_comm_group_id self.model_comm_group_rank = model_comm_group_rank self.model_comm_num_groups = model_comm_num_groups self.reader_group_rank = reader_group_rank + self.reader_group_size = reader_group_size - if self.read_shards: + if self.reader_group_size > 1: # get the grid shard size and start/end indices - grid_shard_size = self.grid_size // self.model_comm_group_size - self.grid_start = self.model_comm_group_rank * grid_shard_size - if self.model_comm_group_rank == self.model_comm_group_size - 1: + grid_shard_size = self.grid_size // self.reader_group_size + self.grid_start = self.reader_group_rank * grid_shard_size + if self.reader_group_rank == self.reader_group_size - 1: self.grid_end = self.grid_size else: - self.grid_end = (self.model_comm_group_rank + 1) * grid_shard_size + self.grid_end = (self.reader_group_rank + 1) * grid_shard_size LOGGER.debug( "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " @@ -256,13 +253,6 @@ def __iter__(self) -> torch.Tensor: Currently it receives data with an ensemble dimension, which is discarded for now. (Until the code is "ensemble native".) """ - if self.reader_group_rank != 0: - # yield dummy data only with shape information for non-root ranks (shape used for broadcast) - shape = (self.rollout + self.multi_step, self.data.shape[2], self.data.shape[3], self.data.shape[1]) - for _ in self.chunk_index_range: - yield torch.tensor(shape, dtype=torch.long) - return - if self.shuffle: shuffled_chunk_indices = self.rng.choice( self.chunk_index_range, @@ -290,9 +280,9 @@ def __iter__(self) -> torch.Tensor: start = i - (self.multi_step - 1) * self.timeincrement end = i + (self.rollout + 1) * self.timeincrement - if self.read_shards: + if self.reader_group_size > 1: # read only a subset of the grid x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end] - else: + else: # read the full grid x = self.data[start : end : self.timeincrement, :, :, :] x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index 53ecf09e..32c96dc6 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -26,22 +26,22 @@ class DDPGroupStrategy(DDPStrategy): """Distributed Data Parallel strategy with group communication.""" - def __init__(self, num_gpus_per_model: int, read_frequency: int, **kwargs: dict) -> None: + def __init__(self, num_gpus_per_model: int, read_group_size: int, **kwargs: dict) -> None: """Initialize the distributed strategy. Parameters ---------- num_gpus_per_model : int Number of GPUs per model to shard over. - read_frequency : int - Frequency of dataloader readers per model group. + read_group_size : int + Number of GPUs per reader group. **kwargs : dict Additional keyword arguments. 
""" super().__init__(**kwargs) self.model_comm_group_size = num_gpus_per_model - self.read_frequency = read_frequency + self.read_group_size = read_group_size def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy" @@ -74,23 +74,23 @@ def setup(self, trainer: pl.Trainer) -> None: self.model_comm_group_size, ) - # set up reader groups by further splitting model_comm_group_ranks with read_frequency: + # set up reader groups by further splitting model_comm_group_ranks with read_group_size: - assert self.model_comm_group_size % self.read_frequency == 0, ( - f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by the read frequency " - f"({self.read_frequency})." + assert self.model_comm_group_size % self.read_group_size == 0, ( + f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size " + f"({self.read_group_size})." ) reader_group_ranks = np.array( [ - np.split(group_ranks, int(self.model_comm_group_size / self.read_frequency)) + np.split(group_ranks, int(self.model_comm_group_size / self.read_group_size)) for group_ranks in model_comm_group_ranks ], - ) # Shape: (num_model_comm_groups, model_comm_grp_size/read_freq, read_freq) + ) # Shape: (num_model_comm_groups, model_comm_grp_size/read_group_size, read_group_size) reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks] reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group( model_comm_group_rank, - self.read_frequency, + self.read_group_size, ) # get all reader groups of the current model group model_reader_groups = reader_groups[model_comm_group_id] @@ -99,7 +99,6 @@ def setup(self, trainer: pl.Trainer) -> None: reader_group_id, reader_group_rank, reader_group_size, - reader_group_root, ) LOGGER.debug( @@ -169,25 +168,25 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, in return model_comm_group_id, model_comm_group_rank, model_comm_num_groups - def get_my_reader_group(self, model_comm_group_rank: int, read_frequency: int) -> tuple[int, int, int]: + def get_my_reader_group(self, model_comm_group_rank: int, read_group_size: int) -> tuple[int, int, int]: """Determine tasks that work together and from a reader group. Parameters ---------- model_comm_group_rank : int Rank within the model communication group. - read_frequency : int - Frequency of dataloader readers per model group. + read_group_size : int + Number of dataloader readers per model group. 
        Returns
         -------
         tuple[int, int, int, int]
             Reader_group id, Reader_group rank, Reader_group size, Reader_group root (global rank)
         """
-        reader_group_id = model_comm_group_rank // read_frequency
-        reader_group_rank = model_comm_group_rank % read_frequency
-        reader_group_size = read_frequency
-        reader_group_root = (self.global_rank // read_frequency) * read_frequency
+        reader_group_id = model_comm_group_rank // read_group_size
+        reader_group_rank = model_comm_group_rank % read_group_size
+        reader_group_size = read_group_size
+        reader_group_root = (self.global_rank // read_group_size) * read_group_size

         return reader_group_id, reader_group_rank, reader_group_size, reader_group_root

@@ -211,15 +210,15 @@ def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.u
         model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group(
             self.model_comm_group_size,
         )
-        _, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_frequency)
+        _, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_group_size)

         dataloader.dataset.set_comm_group_info(
             self.global_rank,
-            self.model_comm_group_size,
             model_comm_group_id,
             model_comm_group_rank,
             model_comm_num_groups,
             reader_group_rank,
+            self.read_group_size,
         )

         return dataloader

diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index 1a77494e..80f6d38e 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -133,9 +133,6 @@ def __init__(
         self.reader_group_id = 0
         self.reader_group_rank = 0
-        self.reader_group_root = 0
-
-        self.read_shards = config.dataloader.read_shards

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x, self.model_comm_group)
@@ -216,17 +213,11 @@ def set_reader_groups(
         reader_group_id: int,
         reader_group_rank: int,
         reader_group_size: int,
-        reader_group_root: int,
     ) -> None:
         self.reader_groups = reader_groups
         self.reader_group_id = reader_group_id
         self.reader_group_rank = reader_group_rank
         self.reader_group_size = reader_group_size
-        self.reader_group_root = reader_group_root
-
-        assert not (
-            self.reader_group_size > 1 and self.read_shards
-        ), "Reading shards is not supported with reader group size > 1"

     def advance_input(
         self,
@@ -263,11 +254,8 @@ def _step(
     ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
         del batch_idx

-        if self.reader_groups is not None and self.reader_group_size > 1:
-            batch = self.broadcast_batch(batch)
-
-        # all gather grid shards from model_comm_group
-        if self.model_comm_group is not None and self.read_shards:
+        # all gather grid shards from reader group
+        if self.reader_group_size > 1:
             batch = self.allgather_batch(batch)

         # for validation not normalized in-place because remappers cannot be applied in-place
@@ -309,37 +297,24 @@ def _step(
         loss *= 1.0 / self.rollout
         return loss, metrics, y_preds

-    def broadcast_batch(self, batch: torch.Tensor) -> torch.Tensor:
-        if self.reader_group_rank != 0:
-            # init batch tensor with correct shape on non-root ranks
-            shape = (batch.shape[0], *tuple(batch[0].tolist()))
-            batch = torch.empty(shape, device=self.device)
-
-        torch.distributed.broadcast(
-            batch,
-            src=self.reader_group_root,
-            group=self.reader_groups[self.reader_group_id],
-        )
-
-        return batch
-
     def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor:
-        # get shard shapes
         grid_size = self.model.metadata["dataset"]["shape"][-1]
+        grid_shard_size = grid_size // self.reader_group_size
+        last_grid_shard_size = 
grid_size - (grid_shard_size * (self.reader_group_size - 1)) + + # prepare tensor list with correct shapes for all_gather shard_shape = list(batch.shape) - # handle cases of last grid shard - shard_shape[-2] = grid_size // self.model_comm_group_size - last_grid_dim = grid_size - (shard_shape[-2] * (self.model_comm_group_size - 1)) + shard_shape[-2] = grid_shard_size + last_shard_shape = list(batch.shape) + last_shard_shape[-2] = last_grid_shard_size - tensor_list = [ - torch.empty(tuple(shard_shape), device=self.device) for _ in range(self.model_comm_group_size - 1) - ] - tensor_list.append(torch.empty((*tuple(shard_shape[:-2]), last_grid_dim, shard_shape[-1]), device=self.device)) + tensor_list = [torch.empty(tuple(shard_shape), device=self.device) for _ in range(self.reader_group_size - 1)] + tensor_list.append(torch.empty(last_shard_shape, device=self.device)) torch.distributed.all_gather( tensor_list, batch, - group=self.model_comm_group, + group=self.reader_groups[self.reader_group_id], ) return torch.cat(tensor_list, dim=-2) diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index c1129a57..fa3260b5 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -334,7 +334,7 @@ def strategy(self) -> DDPGroupStrategy: """Training strategy.""" return DDPGroupStrategy( self.config.hardware.num_gpus_per_model, - self.config.dataloader.get("read_frequency", 1), + self.config.dataloader.get("read_group_size", 1), static_graph=not self.config.training.accum_grad_batches > 1, ) From 93a7e0680f0279aad766570134d71a0e78242836 Mon Sep 17 00:00:00 2001 From: japols Date: Mon, 28 Oct 2024 10:09:27 +0000 Subject: [PATCH 09/13] refactor: adress PR review comments --- src/anemoi/training/data/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index e6a10943..40065e06 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -87,7 +87,8 @@ def __init__( assert self.multi_step > 0, "Multistep value must be greater than zero." 
self.ensemble_dim: int = 2 self.ensemble_size = self.data.shape[self.ensemble_dim] - self.grid_size = self.data.shape[-1] + self.grid_dim: int = -1 + self.grid_size = self.data.shape[self.grid_dim] @cached_property def statistics(self) -> dict: From 3b69b33362cee44c95ad084c97d5fb05e049ac76 Mon Sep 17 00:00:00 2001 From: japols Date: Mon, 11 Nov 2024 13:04:05 +0000 Subject: [PATCH 10/13] fix: allgather batch in callbacks --- src/anemoi/training/diagnostics/callbacks/evaluation.py | 4 ++-- src/anemoi/training/diagnostics/callbacks/plot.py | 6 ++++-- src/anemoi/training/train/forecaster.py | 9 ++++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index fc812121..9cb9c0ef 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -15,7 +15,6 @@ import torch from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_only if TYPE_CHECKING: import pytorch_lightning as pl @@ -103,7 +102,6 @@ def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, rank_zero_only=True, ) - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -122,5 +120,7 @@ def on_validation_batch_end( dtype = precision_mapping.get(prec) context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() + batch = pl_module.allgather_batch(batch) + with context: self._eval(pl_module, batch) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 93f9aa17..3814c70c 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -206,7 +206,6 @@ def _plot( ) -> None: """Plotting function to be implemented by subclasses.""" - @rank_zero_only def on_validation_batch_end( self, trainer, @@ -217,6 +216,8 @@ def on_validation_batch_end( **kwargs, ) -> None: if batch_idx % self.every_n_batches == 0: + batch = pl_module.allgather_batch(batch) + self.plot( trainer, pl_module, @@ -403,7 +404,6 @@ def _plot( int(time.time() - start_time), ) - @rank_zero_only def on_validation_batch_end( self, trainer, @@ -413,6 +413,8 @@ def on_validation_batch_end( batch_idx: int, ) -> None: if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0: + batch = pl_module.allgather_batch(batch) + precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index cf534860..54d7ff71 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -379,9 +379,6 @@ def rollout_step( None None """ - if self.reader_group_size > 1: - batch = self.allgather_batch(batch) - # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) # start rollout of preprocessed batch @@ -423,6 +420,8 @@ def _step( validation_mode: bool = False, ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx + batch = self.allgather_batch(batch) + loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) metrics = {} y_preds = [] @@ -442,6 +441,10 @@ def _step( def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: grid_size = 
self.model.metadata["dataset"]["shape"][-1] + + if grid_size == batch.shape[-2]: + return batch # already have the full grid + grid_shard_size = grid_size // self.reader_group_size last_grid_shard_size = grid_size - (grid_shard_size * (self.reader_group_size - 1)) From 3c01e14163d89f2fbb9d31847a5e0d3bcac509bc Mon Sep 17 00:00:00 2001 From: japols Date: Mon, 18 Nov 2024 12:44:49 +0000 Subject: [PATCH 11/13] cleanup, async plot warning --- CHANGELOG.md | 2 +- .../training/diagnostics/callbacks/evaluation.py | 4 ++-- src/anemoi/training/diagnostics/callbacks/plot.py | 4 +++- src/anemoi/training/train/forecaster.py | 12 ++++++++++++ src/anemoi/training/train/train.py | 2 +- 5 files changed, 19 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e84d035d..311ddec0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ Keep it human-readable, your future self will thank you! - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) +- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) ### Changed - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118) @@ -78,7 +79,6 @@ Keep it human-readable, your future self will thank you! - Add anemoi-transform link to documentation - Codeowners file (#56) - Changelog merge strategy (#56) -- Feature: Add reader groups to reduce CPU memory usage [#76](https://github.com/ecmwf/anemoi-training/pull/76) - Contributors file (#106) #### Miscellaneous diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index 9cb9c0ef..cbc929d6 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -112,6 +112,8 @@ def on_validation_batch_end( ) -> None: del outputs # outputs are not used if batch_idx % self.every_n_batches == 0: + batch = pl_module.allgather_batch(batch) + precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, @@ -120,7 +122,5 @@ def on_validation_batch_end( dtype = precision_mapping.get(prec) context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() - batch = pl_module.allgather_batch(batch) - with context: self._eval(pl_module, batch) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 40669458..34e07b84 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -210,7 +210,6 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None): super().__init__(config) self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -220,6 +219,9 @@ def on_validation_batch_end( batch_idx: int, **kwargs, ) -> None: + if self.config.diagnostics.plot.asynchronous and self.config.dataloader.read_group_size > 1: + LOGGER.warning("Asynchronous plotting can result in NCCL timeouts with reader_group_size > 1.") + if 
batch_idx % self.every_n_batches == 0: batch = pl_module.allgather_batch(batch) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 46c16bc2..50ffc0cf 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -442,6 +442,18 @@ def _step( return loss, metrics, y_preds def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: + """Allgather the batch-shards across the reader group. + + Parameters + ---------- + batch : torch.Tensor + Batch-shard of current reader rank + + Returns + ------- + torch.Tensor + Allgathered (full) batch + """ grid_size = self.model.metadata["dataset"]["shape"][-1] if grid_size == batch.shape[-2]: diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 8af2d612..80fc70d3 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -344,7 +344,7 @@ def strategy(self) -> DDPGroupStrategy: """Training strategy.""" return DDPGroupStrategy( self.config.hardware.num_gpus_per_model, - self.config.dataloader.get("read_group_size", 1), + self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model), static_graph=not self.config.training.accum_grad_batches > 1, ) From 05665c7f8257bf75092049f4f6a71de4055b2fc7 Mon Sep 17 00:00:00 2001 From: japols Date: Tue, 19 Nov 2024 08:33:37 +0000 Subject: [PATCH 12/13] fix: changelog, read_group default --- CHANGELOG.md | 3 ++- src/anemoi/training/config/dataloader/native_grid.yaml | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a84b0df..64dcc4ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ Keep it human-readable, your future self will thank you! ### Added +- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) + ### Changed ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 @@ -46,7 +48,6 @@ Keep it human-readable, your future self will thank you! - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) -- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) ### Changed diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index 5018541d..9513ecc7 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -8,9 +8,9 @@ pin_memory: True # which is then all-gathered between the group. # This can reduce CPU memory usage as well as increase dataloader throughput. # The number of GPUs per model must be divisible by read_group_size. -# Good values are num_gpus_per_model or num_gpus_per_node. +# To disable, set to 1. 
# ============ -read_group_size: 1 +read_group_size: ${hardware.num_gpus_per_model} num_workers: training: 8 From aea0ad784e577a2866b0ecf05af2299607108a30 Mon Sep 17 00:00:00 2001 From: japols Date: Tue, 19 Nov 2024 10:41:02 +0000 Subject: [PATCH 13/13] docs: docstring, fix: warning only on rank 0 --- docs/user-guide/distributed.rst | 4 ++++ src/anemoi/training/diagnostics/callbacks/plot.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/user-guide/distributed.rst b/docs/user-guide/distributed.rst index 40ee4d65..68d7697a 100644 --- a/docs/user-guide/distributed.rst +++ b/docs/user-guide/distributed.rst @@ -45,6 +45,10 @@ number of GPUs you wish to shard the model across. It is recommended to only shard if the model does not fit in GPU memory, as data distribution is a much more efficient way to parallelise the training. +When using model sharding, ``config.dataloader.read_group_size`` allows +for sharded data loading in subgroups. This should be set to the number +of GPUs per model for optimal performance. + ********* Example ********* diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 34e07b84..f5e07f4b 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -219,7 +219,11 @@ def on_validation_batch_end( batch_idx: int, **kwargs, ) -> None: - if self.config.diagnostics.plot.asynchronous and self.config.dataloader.read_group_size > 1: + if ( + self.config.diagnostics.plot.asynchronous + and self.config.dataloader.read_group_size > 1 + and pl_module.local_rank == 0 + ): LOGGER.warning("Asynchronous plotting can result in NCCL timeouts with reader_group_size > 1.") if batch_idx % self.every_n_batches == 0:
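
Note on the mechanics introduced by this series: each model comm group is split into reader groups of read_group_size ranks, every rank in a reader group reads only a contiguous shard of the grid, the last rank absorbs the remainder, and the shards are re-joined with an all-gather over the reader group. The standalone Python sketch below illustrates that integer arithmetic outside the training code. It is not part of the patch; the model-comm-group mapping (consecutive ranks, global_rank // num_gpus_per_model) is an assumption rather than something shown in the diffs, and the function and variable names are hypothetical.

# Standalone sketch (not part of the patch): mirrors the rank -> reader-group
# mapping of get_my_reader_group() and the grid-shard bounds of
# set_comm_group_info(), including the uneven last shard that
# allgather_batch() reassembles.


def reader_group_layout(
    global_rank: int,
    num_gpus_per_model: int,
    read_group_size: int,
    grid_size: int,
) -> dict:
    assert num_gpus_per_model % read_group_size == 0

    # assumed convention: consecutive global ranks form a model comm group
    model_comm_group_id = global_rank // num_gpus_per_model
    model_comm_group_rank = global_rank % num_gpus_per_model

    # split the model comm group into reader groups of read_group_size ranks
    reader_group_id = model_comm_group_rank // read_group_size
    reader_group_rank = model_comm_group_rank % read_group_size
    reader_group_root = (global_rank // read_group_size) * read_group_size

    # each reader rank loads a contiguous grid shard; the last rank takes the remainder
    grid_shard_size = grid_size // read_group_size
    grid_start = reader_group_rank * grid_shard_size
    if reader_group_rank == read_group_size - 1:
        grid_end = grid_size
    else:
        grid_end = (reader_group_rank + 1) * grid_shard_size

    return {
        "model_comm_group_id": model_comm_group_id,
        "reader_group_id": reader_group_id,
        "reader_group_rank": reader_group_rank,
        "reader_group_root": reader_group_root,
        "grid_slice": (grid_start, grid_end),
    }


if __name__ == "__main__":
    # e.g. 8 GPUs, 4 GPUs per model, reader groups of 2 ranks, 40320 grid points
    shards = []
    for rank in range(8):
        layout = reader_group_layout(rank, num_gpus_per_model=4, read_group_size=2, grid_size=40320)
        print(rank, layout)
        if layout["model_comm_group_id"] == 0 and layout["reader_group_id"] == 0:
            shards.append(layout["grid_slice"])
    # the shards of one reader group tile the full grid without gaps or overlap
    assert shards[0][0] == 0 and shards[-1][1] == 40320

With the default of read_group_size = num_gpus_per_model set in train.py above, each model comm group forms exactly one reader group, so every GPU loads only its own grid shard and the full batch is reassembled once per step by allgather_batch.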