Feature/improve dataloader memory #76

Open: wants to merge 19 commits into base: develop (changes shown from 17 of the 19 commits).

Commits:
cdaf082  feat: initial implementation of dataloader memory optimization (japols, Sep 24, 2024)
fcc7c93  fix: non-reader tasks actually return before reading (japols, Oct 2, 2024)
5d171c7  feat: add reader_group to define per-model_comm_group read behaviour … (japols, Oct 7, 2024)
8c16e54  Merge remote-tracking branch 'origin' into feature/improve-dataloader… (japols, Oct 7, 2024)
ee94593  docs: cleanup, add comments (japols, Oct 9, 2024)
3c6b5c9  refactor: Pass model/reader group information from DDPGroupStrategy i… (japols, Oct 9, 2024)
57a13c5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 9, 2024)
bcd0fe6  Merge branch 'develop' into feature/improve-dataloader-memory (gabrieloks, Oct 24, 2024)
9a22225  feat: Add support for sharded reading in dataloader (japols, Oct 24, 2024)
6615c97  refactor: merge read groups with sharded reading functionality in dat… (japols, Oct 25, 2024)
93a7e06  refactor: adress PR review comments (japols, Oct 28, 2024)
ef9fad9  Merge remote-tracking branch 'origin/develop' into feature/improve-da… (japols, Oct 31, 2024)
1d4f779  Merge remote-tracking branch 'origin/develop' into feature/improve-da… (japols, Nov 11, 2024)
3b69b33  fix: allgather batch in callbacks (japols, Nov 11, 2024)
937943b  Merge branch 'develop' into feature/improve-dataloader-memory (HCookie, Nov 15, 2024)
3c01e14  cleanup, async plot warning (japols, Nov 18, 2024)
24cab45  Merge branch 'develop' into feature/improve-dataloader-memory (japols, Nov 18, 2024)
05665c7  fix: changelog, read_group default (japols, Nov 19, 2024)
aea0ad7  docs: docstring, fix: warning only on rank 0 (japols, Nov 19, 2024)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -46,6 +46,7 @@ Keep it human-readable, your future self will thank you!
- Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/)
- New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)
- New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76)
Review comment (Member):
Due to changelog bot fun, this is in the wrong place; could you put it in the Unreleased section, please? (Sorry for the chore.)

Reply (Member, Author):
done: 05665c7

### Changed

11 changes: 11 additions & 0 deletions src/anemoi/training/config/dataloader/native_grid.yaml
@@ -1,6 +1,17 @@
prefetch_factor: 2
pin_memory: True

# ============
# read_group_size:
# Form subgroups of model comm groups that read data together.
# Each reader in the group reads only 1/read_group_size of the data,
# which is then all-gathered across the group.
# This can reduce CPU memory usage as well as increase dataloader throughput.
# The number of GPUs per model must be divisible by read_group_size.
# Good values are num_gpus_per_model or num_gpus_per_node.
# ============
read_group_size: 1

num_workers:
training: 8
validation: 8
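
For intuition, here is a minimal sketch of the grid shard each reader ends up with under the read_group_size setting above. The values are illustrative and the names only mirror the config and dataset attributes; this is not code from the PR.

```python
# Illustrative values only (assumptions, not taken from this PR).
grid_size = 40320          # total number of grid points
read_group_size = 4        # dataloader.read_group_size
reader_group_rank = 2      # this process's rank within its reader group

shard = grid_size // read_group_size
grid_start = reader_group_rank * shard
# the last reader also takes any remainder so the whole grid is covered
grid_end = grid_size if reader_group_rank == read_group_size - 1 else grid_start + shard

print(grid_start, grid_end)  # 20160 30240 for these values
```

Each rank then reads only its [grid_start, grid_end) slice of the grid, and the shards are all-gathered across the reader group afterwards.
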
29 changes: 0 additions & 29 deletions src/anemoi/training/data/datamodule.py
@@ -9,7 +9,6 @@


import logging
import os
from functools import cached_property
from typing import Callable

@@ -43,31 +42,6 @@ def __init__(self, config: DictConfig) -> None:

self.config = config

self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank
self.model_comm_group_id = (
self.global_rank // self.config.hardware.num_gpus_per_model
) # id of the model communication group the rank is participating in
self.model_comm_group_rank = (
self.global_rank % self.config.hardware.num_gpus_per_model
) # rank within one model communication group
total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes
assert (
total_gpus
) % self.config.hardware.num_gpus_per_model == 0, (
f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}"
)
self.model_comm_num_groups = (
self.config.hardware.num_gpus_per_node
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
) # number of model communication groups
LOGGER.debug(
"Rank %d model communication group number %d, with local model communication group rank %d",
self.global_rank,
self.model_comm_group_id,
self.model_comm_group_rank,
)

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
@@ -182,9 +156,6 @@ def _get_dataset(
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
shuffle=shuffle,
label=label,
)
82 changes: 67 additions & 15 deletions src/anemoi/training/data/dataset.py
@@ -36,9 +36,6 @@ def __init__(
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
shuffle: bool = True,
label: str = "generic",
) -> None:
@@ -54,12 +51,6 @@
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
device group ID, default 0
model_comm_num_groups : int, optional
total number of device groups, by default 1
shuffle : bool, optional
Shuffle batches, by default True
label : str, optional
@@ -77,11 +68,14 @@ def __init__(
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0

# DDP-relevant info
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.model_comm_group_id = model_comm_group_id
self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
# lazy init model and reader group info, will be set by the DDPGroupStrategy:
self.model_comm_group_rank = 0
self.model_comm_num_groups = 1
self.model_comm_group_id = 0
self.global_rank = 0

self.reader_group_rank = 0
self.reader_group_size = 1

# additional state vars (lazy init)
self.n_samples_per_worker = 0
@@ -93,6 +87,8 @@ def __init__(
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_dim: int = -1
self.grid_size = self.data.shape[self.grid_dim]

@cached_property
def statistics(self) -> dict:
@@ -128,6 +124,58 @@ def valid_date_indices(self) -> np.ndarray:
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)

def set_comm_group_info(
self,
global_rank: int,
model_comm_group_id: int,
model_comm_group_rank: int,
model_comm_num_groups: int,
reader_group_rank: int,
reader_group_size: int,
) -> None:
"""Set model and reader communication group information (called by DDPGroupStrategy).

Parameters
----------
global_rank : int
Global rank
model_comm_group_id : int
Model communication group ID
model_comm_group_rank : int
Model communication group rank
model_comm_num_groups : int
Number of model communication groups
reader_group_rank : int
Reader group rank
reader_group_size : int
Reader group size
"""
self.global_rank = global_rank
self.model_comm_group_id = model_comm_group_id
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.reader_group_rank = reader_group_rank
self.reader_group_size = reader_group_size

if self.reader_group_size > 1:
# get the grid shard size and start/end indices
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
self.grid_end = self.grid_size
else:
self.grid_end = (self.reader_group_rank + 1) * grid_shard_size

LOGGER.debug(
"NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, "
"model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d",
global_rank,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
reader_group_rank,
)

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.

@@ -233,7 +281,11 @@ def __iter__(self) -> torch.Tensor:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
if self.reader_group_size > 1: # read only a subset of the grid
x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
else: # read the full grid
x = self.data[start : end : self.timeincrement, :, :, :]

x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1

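
A short, self-contained sketch of the bookkeeping a strategy could do before calling the new set_comm_group_info() hook. The rank-to-reader-group mapping below is an assumption for illustration; the real logic lives in DDPGroupStrategy, which is not part of this diff.

```python
# Hypothetical helper (not PR code): derive the arguments that would be passed
# to NativeGridDataset.set_comm_group_info() on each rank.
def comm_group_info(global_rank: int, world_size: int, num_gpus_per_model: int, read_group_size: int) -> dict:
    assert num_gpus_per_model % read_group_size == 0, "read_group_size must divide num_gpus_per_model"
    return {
        "global_rank": global_rank,
        "model_comm_group_id": global_rank // num_gpus_per_model,
        "model_comm_group_rank": global_rank % num_gpus_per_model,
        "model_comm_num_groups": world_size // num_gpus_per_model,
        # assumed mapping of ranks to positions inside a reader subgroup
        "reader_group_rank": (global_rank % num_gpus_per_model) % read_group_size,
        "reader_group_size": read_group_size,
    }


# comm_group_info(global_rank=5, world_size=8, num_gpus_per_model=4, read_group_size=2)
# -> model group 1, model group rank 1, 2 model groups, reader rank 1 of 2;
# the keys match the keyword arguments of dataset.set_comm_group_info() above.
```
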
4 changes: 2 additions & 2 deletions src/anemoi/training/diagnostics/callbacks/evaluation.py
@@ -15,7 +15,6 @@

import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

if TYPE_CHECKING:
import pytorch_lightning as pl
@@ -103,7 +102,6 @@ def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict,
rank_zero_only=True,
)

@rank_zero_only
def on_validation_batch_end(
self,
trainer: pl.Trainer,
@@ -114,6 +112,8 @@ def on_validation_batch_end(
) -> None:
del outputs # outputs are not used
if batch_idx % self.every_n_batches == 0:
batch = pl_module.allgather_batch(batch)

precision_mapping = {
"16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,
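
This callback and the plot callbacks below now call pl_module.allgather_batch(batch) so that evaluation and plotting see the full grid even when each rank only read a shard. The implementation of allgather_batch is not shown in this diff; the following is a rough sketch of what such a gather could look like, assuming equal-sized shards (grid_size divisible by reader_group_size) and a torch.distributed process group for the readers.

```python
# Rough sketch only: pl_module.allgather_batch() is not shown in this diff.
# Assumes equal-sized shards and that `reader_group` is the torch.distributed
# process group containing the readers of this model instance.
import torch
import torch.distributed as dist


def allgather_batch_sketch(batch: torch.Tensor, reader_group, reader_group_size: int) -> torch.Tensor:
    if reader_group_size == 1:
        return batch  # this rank already read the full grid
    shards = [torch.empty_like(batch) for _ in range(reader_group_size)]
    dist.all_gather(shards, batch.contiguous(), group=reader_group)
    # grid points are assumed to sit on the second-to-last dimension
    return torch.cat(shards, dim=-2)
```

With uneven shards (the last reader taking the remainder of the grid), the tensors would first need padding or a variable-size gather before concatenation.
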
9 changes: 7 additions & 2 deletions src/anemoi/training/diagnostics/callbacks/plot.py
@@ -210,7 +210,6 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None):
super().__init__(config)
self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch

@rank_zero_only
def on_validation_batch_end(
self,
trainer: pl.Trainer,
@@ -220,7 +219,12 @@ def on_validation_batch_end(
batch_idx: int,
**kwargs,
) -> None:
if self.config.diagnostics.plot.asynchronous and self.config.dataloader.read_group_size > 1:
LOGGER.warning("Asynchronous plotting can result in NCCL timeouts with reader_group_size > 1.")

if batch_idx % self.every_n_batches == 0:
batch = pl_module.allgather_batch(batch)

self.plot(
trainer,
pl_module,
@@ -407,7 +411,6 @@ def _plot(
int(time.time() - start_time),
)

@rank_zero_only
def on_validation_batch_end(
self,
trainer: pl.Trainer,
@@ -417,6 +420,8 @@ def on_validation_batch_end(
batch_idx: int,
) -> None:
if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0:
batch = pl_module.allgather_batch(batch)

precision_mapping = {
"16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,