Feature/improve dataloader memory #76

Open · wants to merge 19 commits into develop from feature/improve-dataloader-memory

Changes from 10 commits (19 commits total)
cdaf082
feat: initial implementation of dataloader memory optimization
japols Sep 24, 2024
fcc7c93
fix: non-reader tasks actually return before reading
japols Oct 2, 2024
5d171c7
feat: add reader_group to define per-model_comm_group read behaviour …
japols Oct 7, 2024
8c16e54
Merge remote-tracking branch 'origin' into feature/improve-dataloader…
japols Oct 7, 2024
ee94593
docs: cleanup, add comments
japols Oct 9, 2024
3c6b5c9
refactor: Pass model/reader group information from DDPGroupStrategy i…
japols Oct 9, 2024
57a13c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
bcd0fe6
Merge branch 'develop' into feature/improve-dataloader-memory
gabrieloks Oct 24, 2024
9a22225
feat: Add support for sharded reading in dataloader
japols Oct 24, 2024
6615c97
refactor: merge read groups with sharded reading functionality in dat…
japols Oct 25, 2024
93a7e06
refactor: adress PR review comments
japols Oct 28, 2024
ef9fad9
Merge remote-tracking branch 'origin/develop' into feature/improve-da…
japols Oct 31, 2024
1d4f779
Merge remote-tracking branch 'origin/develop' into feature/improve-da…
japols Nov 11, 2024
3b69b33
fix: allgather batch in callbacks
japols Nov 11, 2024
937943b
Merge branch 'develop' into feature/improve-dataloader-memory
HCookie Nov 15, 2024
3c01e14
cleanup, async plot warning
japols Nov 18, 2024
24cab45
Merge branch 'develop' into feature/improve-dataloader-memory
japols Nov 18, 2024
05665c7
fix: changelog, read_group default
japols Nov 19, 2024
aea0ad7
docs: docstring, fix: warning only on rank 0
japols Nov 19, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -38,6 +38,7 @@ Keep it human-readable, your future self will thank you!
- Add anemoi-transform link to documentation
- Codeowners file (#56)
- Changelog merge strategy (#56)
- Feature: Add reader groups to reduce CPU memory usage [#76](https://github.com/ecmwf/anemoi-training/pull/76)

#### Miscellaneous

11 changes: 11 additions & 0 deletions src/anemoi/training/config/dataloader/native_grid.yaml
@@ -1,6 +1,17 @@
prefetch_factor: 2
pin_memory: True

# ============
# read_group_size:
# Form subgroups of model comm groups that read data together.
# Each reader in the group only reads 1/read_group_size of the data
# which is then all-gathered between the group.
# This can reduce CPU memory usage as well as increase dataloader throughput.
# The number of GPUs per model must be divisible by read_group_size.
# Good values are num_gpus_per_model or num_gpus_per_node.
# ============
read_group_size: 1

num_workers:
training: 8
validation: 8
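As a rough, editorial illustration of the comment above (not part of the PR; the grid size and GPU count below are made-up numbers), the per-rank read volume shrinks with `read_group_size` before the shards are reassembled by an all-gather inside the read group:

```python
# Illustrative sketch only: what a single rank reads for different read_group_size
# values. grid_size and num_gpus_per_model are assumed example numbers.

grid_size = 40_320          # total number of grid points (assumed)
num_gpus_per_model = 4

for read_group_size in (1, 2, 4):
    assert num_gpus_per_model % read_group_size == 0, "read_group_size must divide num_gpus_per_model"
    points_per_rank = grid_size // read_group_size
    print(
        f"read_group_size={read_group_size}: each rank reads ~{points_per_rank} grid points, "
        f"then the read group all-gathers them back to the full {grid_size}",
    )
```

Setting `read_group_size: 1` keeps the current behaviour (every rank reads the full grid); per the comment above, `num_gpus_per_model` or `num_gpus_per_node` trade a little extra communication for the largest memory saving.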
29 changes: 0 additions & 29 deletions src/anemoi/training/data/datamodule.py
@@ -9,7 +9,6 @@


import logging
import os
from functools import cached_property
from typing import Callable

@@ -59,31 +58,6 @@ def __init__(self, config: DictConfig) -> None:
timestep,
)

self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank
self.model_comm_group_id = (
self.global_rank // self.config.hardware.num_gpus_per_model
) # id of the model communication group the rank is participating in
self.model_comm_group_rank = (
self.global_rank % self.config.hardware.num_gpus_per_model
) # rank within one model communication group
total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes
assert (
total_gpus
) % self.config.hardware.num_gpus_per_model == 0, (
f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}"
)
self.model_comm_num_groups = (
self.config.hardware.num_gpus_per_node
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
) # number of model communication groups
LOGGER.debug(
"Rank %d model communication group number %d, with local model communication group rank %d",
self.global_rank,
self.model_comm_group_id,
self.model_comm_group_rank,
)

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
@@ -173,9 +147,6 @@ def _get_dataset(
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
shuffle=shuffle,
label=label,
)
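The datamodule no longer derives the comm-group layout from `SLURM_PROCID`; that information now travels from the strategy to the dataset via `set_comm_group_info` (see `process_dataloader` further down). A condensed, stand-alone sketch of that wiring, using invented stand-in classes rather than the real anemoi-training ones:

```python
# Stand-in sketch (not the real classes) of the new flow: the strategy computes the
# group layout and pushes it into the dataset, instead of the datamodule reading
# SLURM environment variables.

class SketchDataset:
    def __init__(self) -> None:
        # lazy defaults, overridden later by the strategy
        self.model_comm_group_id = 0
        self.model_comm_group_rank = 0
        self.reader_group_rank = 0
        self.reader_group_size = 1

    def set_comm_group_info(self, global_rank, group_id, group_rank, num_groups,
                            reader_rank, reader_size) -> None:
        self.model_comm_group_id = group_id
        self.model_comm_group_rank = group_rank
        self.reader_group_rank = reader_rank
        self.reader_group_size = reader_size


def process_dataset(dataset, global_rank, world_size, num_gpus_per_model, read_group_size):
    """Mimics, per rank, what DDPGroupStrategy.process_dataloader does in this PR."""
    group_id = global_rank // num_gpus_per_model
    group_rank = global_rank % num_gpus_per_model
    num_groups = world_size // num_gpus_per_model
    reader_rank = group_rank % read_group_size
    dataset.set_comm_group_info(global_rank, group_id, group_rank, num_groups,
                                reader_rank, read_group_size)
    return dataset


ds = process_dataset(SketchDataset(), global_rank=5, world_size=8,
                     num_gpus_per_model=4, read_group_size=2)
print(ds.model_comm_group_id, ds.model_comm_group_rank, ds.reader_group_rank)  # 1 1 1
```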
81 changes: 66 additions & 15 deletions src/anemoi/training/data/dataset.py
@@ -36,9 +36,6 @@ def __init__(
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
shuffle: bool = True,
label: str = "generic",
) -> None:
@@ -54,12 +51,6 @@
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
device group ID, default 0
model_comm_num_groups : int, optional
total number of device groups, by default 1
shuffle : bool, optional
Shuffle batches, by default True
label : str, optional
@@ -77,11 +68,14 @@
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0

# DDP-relevant info
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.model_comm_group_id = model_comm_group_id
self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
# lazy init model and reader group info, will be set by the DDPGroupStrategy:
self.model_comm_group_rank = 0
self.model_comm_num_groups = 1
self.model_comm_group_id = 0
self.global_rank = 0

self.reader_group_rank = 0
self.reader_group_size = 1

# additional state vars (lazy init)
self.n_samples_per_worker = 0
@@ -93,6 +87,7 @@
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_size = self.data.shape[-1]

@cached_property
def statistics(self) -> dict:
@@ -128,6 +123,58 @@ def valid_date_indices(self) -> np.ndarray:
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)

def set_comm_group_info(
self,
global_rank: int,
model_comm_group_id: int,
model_comm_group_rank: int,
model_comm_num_groups: int,
reader_group_rank: int,
reader_group_size: int,
) -> None:
"""Set model and reader communication group information (called by DDPGroupStrategy).

Parameters
----------
global_rank : int
Global rank
model_comm_group_id : int
Model communication group ID
model_comm_group_rank : int
Model communication group rank
model_comm_num_groups : int
Number of model communication groups
reader_group_rank : int
Reader group rank
reader_group_size : int
Reader group size
"""
self.global_rank = global_rank
self.model_comm_group_id = model_comm_group_id
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.reader_group_rank = reader_group_rank
self.reader_group_size = reader_group_size

if self.reader_group_size > 1:
# get the grid shard size and start/end indices
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
self.grid_end = self.grid_size
else:
self.grid_end = (self.reader_group_rank + 1) * grid_shard_size

LOGGER.debug(
"NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, "
"model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d",
global_rank,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
reader_group_rank,
)

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.

Expand Down Expand Up @@ -233,7 +280,11 @@ def __iter__(self) -> torch.Tensor:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

x = self.data[start : end : self.timeincrement]
if self.reader_group_size > 1: # read only a subset of the grid
x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
else: # read the full grid
x = self.data[start : end : self.timeincrement, :, :, :]

x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1

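A small self-contained sketch of the shard bookkeeping that `set_comm_group_info` performs and the sharded slice taken in `__iter__`; a NumPy array stands in for the anemoi dataset, and all shapes here are invented for illustration:

```python
import numpy as np

# Stand-in for self.data with shape (dates, variables, ensemble, gridpoints);
# the sizes are assumed, purely for illustration.
data = np.zeros((10, 5, 2, 103))
grid_size = data.shape[-1]

reader_group_size = 4
for reader_group_rank in range(reader_group_size):
    grid_shard_size = grid_size // reader_group_size
    grid_start = reader_group_rank * grid_shard_size
    # the last reader absorbs the remainder so every grid point is read exactly once
    if reader_group_rank == reader_group_size - 1:
        grid_end = grid_size
    else:
        grid_end = (reader_group_rank + 1) * grid_shard_size

    # sharded read, mirroring the reader_group_size > 1 branch of __iter__
    x = data[0:4:1, :, :, grid_start:grid_end]
    print(f"reader {reader_group_rank}: grid points [{grid_start}, {grid_end}) -> shard shape {x.shape}")
```

With `grid_size = 103` and four readers, the shards are 25, 25, 25 and 28 points wide; the full grid is reassembled later by an all-gather across the read group (not shown in this file).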
130 changes: 110 additions & 20 deletions src/anemoi/training/distributed/strategy.py
@@ -9,7 +9,6 @@


import logging
import os

import numpy as np
import pytorch_lightning as pl
@@ -27,19 +26,22 @@
class DDPGroupStrategy(DDPStrategy):
"""Distributed Data Parallel strategy with group communication."""

def __init__(self, num_gpus_per_model: int, **kwargs: dict) -> None:
def __init__(self, num_gpus_per_model: int, read_group_size: int, **kwargs: dict) -> None:
"""Initialize the distributed strategy.

Parameters
----------
num_gpus_per_model : int
Number of GPUs per model to shard over.
read_group_size : int
Number of GPUs per reader group.
**kwargs : dict
Additional keyword arguments.

"""
super().__init__(**kwargs)
self.model_comm_group_size = num_gpus_per_model
self.read_group_size = read_group_size

def setup(self, trainer: pl.Trainer) -> None:
assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy"
@@ -60,18 +62,56 @@ def setup(self, trainer: pl.Trainer) -> None:
torch.distributed.new_group(x) for x in model_comm_group_ranks
] # every rank has to create all of these

model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group(
model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group(
self.model_comm_group_size,
)
model_comm_group = model_comm_groups[model_comm_group_id]
self.model.set_model_comm_group(model_comm_group)
self.model.set_model_comm_group(
model_comm_group,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
self.model_comm_group_size,
)

# set up reader groups by further splitting model_comm_group_ranks with read_group_size:

assert self.model_comm_group_size % self.read_group_size == 0, (
f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size "
f"({self.read_group_size})."
)

reader_group_ranks = np.array(
[
np.split(group_ranks, int(self.model_comm_group_size / self.read_group_size))
for group_ranks in model_comm_group_ranks
],
) # Shape: (num_model_comm_groups, model_comm_grp_size/read_group_size, read_group_size)
reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks]
reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group(
model_comm_group_rank,
self.read_group_size,
)
# get all reader groups of the current model group
model_reader_groups = reader_groups[model_comm_group_id]
self.model.set_reader_groups(
model_reader_groups,
reader_group_id,
reader_group_rank,
reader_group_size,
)

LOGGER.debug(
"Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s",
"Rank %d model_comm_group_id: %d model_comm_group: %s model_comm_group_rank: %d "
"reader_group_id: %d reader_group: %s reader_group_rank: %d reader_group_root (global): %d",
self.global_rank,
str(model_comm_group_nr),
model_comm_group_id,
model_comm_group_rank,
str(model_comm_group_ranks[model_comm_group_id]),
model_comm_group_rank,
reader_group_id,
reader_group_ranks[model_comm_group_id, reader_group_id],
reader_group_rank,
reader_group_root,
)

# register hooks for correct gradient reduction
@@ -109,7 +149,7 @@ def setup(self, trainer: pl.Trainer) -> None:
# seed ranks
self.seed_rnd(model_comm_group_id)

def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndarray, int]:
def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, int]:
"""Determine tasks that work together and from a model group.

Parameters
@@ -119,19 +159,69 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndar

Returns
-------
tuple[int, np.ndarray, int]
Model_comm_group id, Model_comm_group Nr, Model_comm_group rank
tuple[int, int, int]
Model_comm_group id, Model_comm_group rank, Number of model_comm_groups
"""
model_comm_group_id = self.global_rank // num_gpus_per_model
model_comm_group_rank = self.global_rank % num_gpus_per_model
model_comm_num_groups = self.world_size // num_gpus_per_model

return model_comm_group_id, model_comm_group_rank, model_comm_num_groups

def get_my_reader_group(self, model_comm_group_rank: int, read_group_size: int) -> tuple[int, int, int, int]:
"""Determine tasks that work together and from a reader group.

Parameters
----------
model_comm_group_rank : int
Rank within the model communication group.
read_group_size : int
Number of dataloader readers per model group.

Returns
-------
tuple[int, int, int, int]
Reader_group id, Reader_group rank, Reader_group size, Reader_group root (global rank)
"""
model_comm_groups = np.arange(0, self.world_size, dtype=np.int32)
model_comm_groups = np.split(model_comm_groups, self.world_size / num_gpus_per_model)
reader_group_id = model_comm_group_rank // read_group_size
reader_group_rank = model_comm_group_rank % read_group_size
reader_group_size = read_group_size
reader_group_root = (self.global_rank // read_group_size) * read_group_size

return reader_group_id, reader_group_rank, reader_group_size, reader_group_root

model_comm_group_id = None
for i, model_comm_group in enumerate(model_comm_groups):
if self.global_rank in model_comm_group:
model_comm_group_id = i
model_comm_group_nr = model_comm_group
model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0]
return model_comm_group_id, model_comm_group_nr, model_comm_group_rank
def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader:
"""Pass communication group information to the dataloader for distributed training.

Parameters
----------
dataloader : torch.utils.data.DataLoader
Dataloader to process.

Returns
-------
torch.utils.data.DataLoader
Processed dataloader.

"""
dataloader = super().process_dataloader(dataloader)

# pass model and reader group information to the dataloader's dataset
model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group(
self.model_comm_group_size,
)
_, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_group_size)

dataloader.dataset.set_comm_group_info(
self.global_rank,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
reader_group_rank,
self.read_group_size,
)

return dataloader

def seed_rnd(self, model_comm_group_id: int) -> None:
"""Seed the random number generators for the rank."""
@@ -145,7 +235,7 @@ def seed_rnd(self, model_comm_group_id: int) -> None:
"Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, "
"running with random seed: %d, sanity rnd: %s"
),
int(os.environ.get("SLURM_PROCID", "0")),
self.global_rank,
model_comm_group_id,
base_seed,
initial_seed,
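For reference, a plain-NumPy sketch (no `torch.distributed` calls, so it runs on a single process) of how the strategy partitions the global ranks first into model comm groups and then into reader groups; the world size and group sizes below are illustrative:

```python
import numpy as np

world_size = 8
num_gpus_per_model = 4       # model_comm_group_size
read_group_size = 2

# split all global ranks into model comm groups, as in DDPGroupStrategy.setup
model_comm_group_ranks = np.split(
    np.arange(world_size, dtype=int), world_size // num_gpus_per_model
)
# -> [0 1 2 3] and [4 5 6 7]

# split each model comm group further into reader groups
reader_group_ranks = np.array(
    [np.split(group, num_gpus_per_model // read_group_size) for group in model_comm_group_ranks]
)
# shape (num_model_comm_groups, reader_groups_per_model_group, read_group_size):
# [[0 1], [2 3]] and [[4 5], [6 7]]

# per-rank bookkeeping, mirroring get_my_model_comm_group / get_my_reader_group
global_rank = 6
model_comm_group_id = global_rank // num_gpus_per_model                  # 1
model_comm_group_rank = global_rank % num_gpus_per_model                 # 2
reader_group_id = model_comm_group_rank // read_group_size               # 1
reader_group_rank = model_comm_group_rank % read_group_size              # 0
reader_group_root = (global_rank // read_group_size) * read_group_size   # 6

print(reader_group_ranks[model_comm_group_id, reader_group_id])          # [6 7]
print(reader_group_id, reader_group_rank, reader_group_root)             # 1 0 6
```

In the real strategy these rank lists feed `torch.distributed.new_group`, which every rank must call for every group, including the groups it is not a member of.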