Bug description
This is part feature request, part question, and part unwanted behavior. I would like to find a valid way to run DDP training and validation when each GPU process sees a different amount of data from a large iterable dataset. Using the Lightning Trainer as is, training hangs.
For training I have a workaround: a dataloader that loops over the data on each GPU process indefinitely, combined with max_steps instead of max_epochs. For evaluation/validation, however, I have not found a workaround that uses torchmetrics to produce valid metrics without duplicating data (one idea is sketched after the script).
Please see the script below.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
import os
import shutil
import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics.classification.auroc import AUROC
def main(args):
    print(args)
    env = TorchElasticEnvironment()

    # Rank 0 writes two deliberately uneven parquet shards (10 and 20 rows).
    if env.local_rank() == 0:
        path = "example-dataset"
        if os.path.isdir(path):
            shutil.rmtree(path)
        if not os.path.exists(path):
            os.mkdir(path)
        partition_sizes = [10, 20]
        total = 0
        for i, size in enumerate(partition_sizes):
            data = pd.DataFrame(
                {
                    "id": list(range(total, total + size)),
                    "inputs": [np.random.rand(5).tolist() for _ in range(size)],
                    "labels": np.random.randint(0, 2, size).tolist(),
                }
            )
            data.to_parquet(os.path.join(path, f"data{i}.parquet"))
            total += size
    class Model(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(5, 2)
            self.auroc = AUROC(task="binary")

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop; print which ids this rank sees.
            print(f"{self.trainer.global_rank}: {batch['id'].cpu().numpy().tolist()}")
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            return loss

        def validation_step(self, batch, batch_idx):
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            self.auroc(torch.softmax(y_hat, -1)[:, 1], batch["labels"])
            self.log(
                "loss", loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True
            )
            self.log(
                "auroc",
                self.auroc,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer
    # Load the dataset as a stream and shard it across the ranks.
    dataset = load_dataset(
        "parquet",
        data_files=[
            "example-dataset/data0.parquet",
            "example-dataset/data1.parquet",
        ],
        split="train",
        streaming=True,
    )
    dataset = split_dataset_by_node(
        dataset, rank=env.global_rank(), world_size=env.world_size()
    )

    model = Model()

    if args.normal_dataloader:
        # Train the model with plain dataloaders over the uneven shards.
        train_dl = DataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_epochs=1,
        )
        trainer.fit(model, train_dl, val_dl)
    # Solutions for training #
    # 1. Infinite dataloader: keep cycling over the data and use max_steps instead.
    class InfiniteDataLoader(DataLoader):
        """Dataloader that continually cycles over the dataset."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Initialize an iterator over the dataset.
            self.dataset_iterator = super().__iter__()
            self.epoch = 0
            self.iters = 0

        def __iter__(self):
            return self

        def __next__(self):
            try:
                batch = next(self.dataset_iterator)
            except StopIteration:
                # Dataset exhausted; start a fresh iterator.
                self.increment_epoch()
                self.dataset_iterator = super().__iter__()
                batch = next(self.dataset_iterator)
            self.iters += 1
            return batch

        def set_epoch(self, epoch: int):
            """Set the epoch on the dataset to reseed its shuffling generator."""
            # Works for any custom `Dataset` that implements `set_epoch`,
            # including HF `datasets` iterable datasets.
            if hasattr(self.dataset, "set_epoch"):
                self.dataset.set_epoch(epoch)

        def increment_epoch(self):
            self.epoch += 1
            self.iters = 0
            self.set_epoch(self.epoch)
    if args.infinite_dataloader:
        train_dl = InfiniteDataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_steps=4,
        )
        trainer.fit(model, train_dl, val_dl)

    # Solutions for eval #
    # 1. Load into memory and reshard --- would like to avoid this.
    # 2. See the sketch after this script: validate the full stream on every
    #    rank and mask metric updates so no sample is counted twice.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--normal-dataloader",
        action="store_true",
        help="Run a normal dataloader with a sharded dataset",
    )
    parser.add_argument(
        "--infinite-dataloader",
        action="store_true",
        help="Run an infinite dataloader with a sharded dataset, using max_steps",
    )
    main(parser.parse_args())
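One direction I am considering for validation, as a minimal untested sketch: build the val dataloader from the full, unsharded stream so every rank runs the same number of validation steps (no collective can hang), and mask each batch so a given sample updates the metric on exactly one rank (nothing is double counted); torchmetrics then merges the per-rank metric states when the metric is computed at epoch end. Keying the mask off the repro's unique "id" column is my own convention here, and the cost is that every rank runs the forward pass over the entire validation set.

# Sketch (untested): validation_step when every rank gets the full stream.
def validation_step(self, batch, batch_idx):
    batch["inputs"] = torch.vstack(batch["inputs"]).float()
    y_hat = self.model(batch["inputs"])
    # Every rank sees every batch, so all ranks stay in lockstep.
    # Each sample updates the metric on exactly one rank.
    mask = batch["id"] % self.trainer.world_size == self.trainer.global_rank
    self.auroc(torch.softmax(y_hat, -1)[:, 1][mask], batch["labels"][mask])
    self.log("auroc", self.auroc, on_epoch=True, prog_bar=True, sync_dist=True)

# ...with the val dataloader built from the stream *before* split_dataset_by_node:
# val_dl = DataLoader(unsharded_dataset, batch_size=5)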
Error messages and logs
Running torchrun --nproc-per-node 2 example_ddp.py --normal-dataloader
results in the processes hanging, since the data is uneven across ranks: with batch_size=5, rank 0 gets the 10-row shard (2 batches) while rank 1 gets the 20-row shard (4 batches), so rank 0 exhausts its data while rank 1 is still stepping.
Epoch 0: | | 0/? [00:00<?, ?it/s]1: [10, 11, 12, 13, 14]
0: [0, 1, 2, 3, 4]
Epoch 0: | | 1/? [00:00<00:00, 20.45it/s, v_num=57]
0: [5, 6, 7, 8, 9]
1: [15, 16, 17, 18, 19]
Epoch 0: | | 2/? [00:00<00:00, 38.07it/s, v_num=57]
1: [20, 21, 22, 23, 24]
Validation DataLoader 0: |
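For reference, the uneven split can be reproduced outside the Trainer. With two parquet shards and world_size=2, split_dataset_by_node hands each rank one whole file. A quick check, assuming the two files written by the repro script exist:

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

ds = load_dataset(
    "parquet",
    data_files=["example-dataset/data0.parquet", "example-dataset/data1.parquet"],
    split="train",
    streaming=True,
)
for rank in range(2):
    shard = split_dataset_by_node(ds, rank=rank, world_size=2)
    print(rank, sum(1 for _ in shard))  # expected: 0 -> 10, 1 -> 20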
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA A10G
- NVIDIA A10G
- NVIDIA A10G
- NVIDIA A10G
- available: True
- version: 12.4
- Lightning:
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- pytorch-lightning: 2.4.0
- torch: 2.5.1
- torchmetrics: 1.5.1
- Packages:
- absl-py: 2.1.0
- accelerate: 0.34.2
- aiohappyeyeballs: 2.4.3
- aiohttp: 3.10.10
- aiosignal: 1.3.1
- antlr4-python3-runtime: 4.9.3
- astroid: 3.3.5
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 24.2.0
- autocommand: 2.2.2
- autoflake: 2.3.1
- autopep8: 2.3.1
- backports.tarfile: 1.2.0
- black: 24.10.0
- boto3: 1.35.54
- botocore: 1.35.54
- c1-cube-versioning: 0.2.8
- c1-fm-model: 0.2.0
- certifi: 2024.8.30
- cfgv: 3.4.0
- charset-normalizer: 3.4.0
- click: 8.1.7
- comm: 0.2.2
- contourpy: 1.3.0
- coverage: 7.6.4
- cramjam: 2.9.0
- cycler: 0.12.1
- datasets: 2.18.0
- debugpy: 1.8.7
- decorator: 5.1.1
- dill: 0.3.8
- distlib: 0.3.9
- evaluate: 0.4.3
- exceptiongroup: 1.2.2
- executing: 2.1.0
- fastparquet: 2024.5.0
- filelock: 3.16.1
- flake8: 7.1.1
- fonttools: 4.54.1
- frozenlist: 1.5.0
- fsspec: 2024.2.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- grpcio: 1.67.1
- huggingface-hub: 0.26.2
- hydra-callbacks: 0.6.1
- hydra-core: 1.3.2
- identify: 2.6.1
- idna: 3.10
- importlib-metadata: 8.5.0
- importlib-resources: 6.4.5
- inflect: 7.3.1
- iniconfig: 2.0.0
- intake: 2.0.7
- ipykernel: 6.29.5
- ipython: 8.18.1
- isort: 5.13.2
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.1
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- jsonpath-ng: 1.6.1
- jupyter-client: 8.6.3
- jupyter-core: 5.7.2
- kiwisolver: 1.4.7
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- markdown: 3.7
- markdown-it-py: 3.0.0
- markupsafe: 3.0.2
- matplotlib: 3.9.2
- matplotlib-inline: 0.1.7
- mccabe: 0.7.0
- mdurl: 0.1.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.16
- mypy: 1.13.0
- mypy-extensions: 1.0.0
- nbqa: 1.9.0
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- nodeenv: 1.9.1
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.4.5.8
- nvidia-cuda-cupti-cu12: 12.4.127
- nvidia-cuda-nvrtc-cu12: 12.4.127
- nvidia-cuda-runtime-cu12: 12.4.127
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.2.1.3
- nvidia-curand-cu12: 10.3.5.147
- nvidia-cusolver-cu12: 11.6.1.9
- nvidia-cusparse-cu12: 12.3.1.170
- nvidia-nccl-cu12: 2.21.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.4.127
- omegaconf: 2.3.0
- packaging: 24.1
- pandas: 1.5.3
- parso: 0.8.4
- pathspec: 0.12.1
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 11.0.0
- pip: 24.3.1
- platformdirs: 4.3.6
- pluggy: 1.5.0
- ply: 3.11
- pre-commit: 4.0.1
- pre-commit-hooks: 5.0.0
- prompt-toolkit: 3.0.48
- propcache: 0.2.0
- protobuf: 5.28.3
- psutil: 6.1.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- pyarrow: 14.0.1
- pyarrow-hotfix: 0.6
- pycodestyle: 2.12.1
- pydantic: 1.10.18
- pyflakes: 3.2.0
- pygments: 2.18.0
- pylint: 3.3.1
- pyparsing: 3.2.0
- pyrootutils: 1.0.4
- pytest: 8.3.3
- pytest-cov: 6.0.0
- pytest-mock: 3.14.0
- python-dateutil: 2.9.0
- python-dotenv: 1.0.1
- pytorch-lightning: 2.4.0
- pytz: 2024.2
- pyyaml: 6.0.1
- pyzmq: 26.2.0
- regex: 2024.9.11
- requests: 2.32.3
- rich: 13.9.4
- ruamel.yaml: 0.18.6
- ruamel.yaml.clib: 0.2.12
- rubicon-ml: 0.10.3
- s3fs: 0.4.2
- s3transfer: 0.10.3
- safetensors: 0.4.5
- scikit-learn: 1.5.2
- scipy: 1.13.1
- seaborn: 0.13.2
- setuptools: 75.3.0
- six: 1.16.0
- smmap: 5.0.1
- stack-data: 0.6.2
- sympy: 1.13.1
- tensorboard: 2.18.0
- tensorboard-data-server: 0.7.2
- threadpoolctl: 3.5.0
- tokenize-rt: 6.1.0
- tokenizers: 0.20.3
- tomli: 2.0.2
- tomlkit: 0.13.2
- torch: 2.5.1
- torchmetrics: 1.5.1
- tornado: 6.4.1
- tqdm: 4.66.6
- traitlets: 5.14.3
- transformers: 4.45.2
- triton: 3.1.0
- typeguard: 4.3.0
- typing-extensions: 4.12.2
- urllib3: 1.26.20
- virtualenv: 20.27.1
- wcwidth: 0.2.13
- werkzeug: 3.1.2
- wheel: 0.44.0
- xxhash: 3.5.0
- yarl: 1.17.1
- zipp: 3.20.2
- System:
- OS: Linux
- architecture: 64bit
- processor: x86_64
- python: 3.9.20
- release: 5.10.226-214.880.amzn2.x86_64
- version: #1 SMP Tue Oct 8 16:18:15 UTC 2024
More info
No response
cc @Borda