How to deal with uneven inputs in DDP with sharded data without hanging #20404

Open
ssharpe42 opened this issue Nov 7, 2024 · 0 comments
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.4.x

ssharpe42 commented Nov 7, 2024

Bug description

This is partly a feature request, partly a question, and partly a report of unwanted behavior. I would like to find a valid way to use a different amount of data on each GPU process for DDP training and validation with large iterable datasets. When I use the Lightning Trainer as is, training hangs.

For training, I have found a workaround: a dataloader that loops over the data indefinitely on each GPU process, combined with max_steps instead of max_epochs. For evaluation/validation, however, I have not found a workaround that produces valid torchmetrics metrics without duplicating data.
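One alternative to the infinite dataloader plus max_steps, sketched here under assumptions: every rank agrees once per step on whether all ranks still have a batch, and all ranks stop as soon as one of them runs out. The names `stop_when_any_rank_exhausted`, `torch_global_all` and `SyncedLoader` are hypothetical helpers, not Lightning or torchmetrics APIs; the only library call is `torch.distributed.all_reduce` with `ReduceOp.MIN`. It assumes the Trainer accepts a plain iterable as the train dataloader and that the wrapper runs in the main process, not inside DataLoader workers. Longer ranks drop their trailing batches, which is usually tolerable for training but not for evaluation.

```python
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
_EXHAUSTED = object()  # sentinel so a falsy batch is not mistaken for "no data"


def stop_when_any_rank_exhausted(
    batches: Iterable[T], global_all: Callable[[bool], bool]
) -> Iterator[T]:
    """Yield local batches only while every rank still has one.

    global_all(flag) must return True iff flag is True on all ranks. Each rank
    calls it exactly once per step, so all ranks run the same number of steps
    and issue the same sequence of collectives. Longer ranks drop their tail.
    """
    it = iter(batches)
    while True:
        batch = next(it, _EXHAUSTED)
        if not global_all(batch is not _EXHAUSTED):
            return
        yield batch


def torch_global_all(flag: bool) -> bool:
    """Hypothetical global_all: MIN all-reduce of a 0/1 flag across ranks."""
    import torch
    import torch.distributed as dist

    if not (dist.is_available() and dist.is_initialized()):
        return flag  # single process: nothing to agree on
    # NCCL only reduces CUDA tensors; gloo reduces CPU tensors.
    if dist.get_backend() == "nccl":
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    t = torch.tensor([int(flag)], device=device)
    dist.all_reduce(t, op=dist.ReduceOp.MIN)
    return bool(t.item())


class SyncedLoader:
    """Re-iterable wrapper: each iter() call starts a fresh, synchronized pass."""

    def __init__(
        self, loader: Iterable, global_all: Callable[[bool], bool] = torch_global_all
    ):
        self.loader = loader
        self.global_all = global_all

    def __iter__(self):
        return stop_when_any_rank_exhausted(self.loader, self.global_all)
```

Untested with Lightning, usage would look like trainer.fit(model, SyncedLoader(DataLoader(dataset, batch_size=5)), val_dl) with max_epochs. The cost is one tiny extra all-reduce per step, and, unlike max_steps, it needs no prior knowledge of how many batches each rank holds.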

Please see the script below.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import os
import shutil

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics.classification.auroc import AUROC


def main(args):
    print(args)
    env = TorchElasticEnvironment()

    if env.local_rank() == 0:
        path = "example-dataset"

        if os.path.isdir(path):
            shutil.rmtree(path)

        if not os.path.exists(path):
            os.mkdir(path)

        partition_sizes = [10, 20]
        total = 0
        for i, size in enumerate(partition_sizes):
            data = pd.DataFrame(
                {
                    "id": list(range(total, total + size)),
                    "inputs": [np.random.rand(5).tolist() for _ in range(size)],
                    "labels": np.random.randint(0, 2, size).tolist(),
                }
            )

            data.to_parquet(os.path.join(path, f"data{i}.parquet"))
            total += size

    class Model(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(5, 2)
            self.auroc = AUROC(task="binary")

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop.
            print(f"{self.trainer.global_rank}: {batch['id'].cpu().numpy().tolist()} ")
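            # default_collate turns each row's 5-float list into a list of 5 tensors
            # of shape (B,); stack along dim=1 to get (B, 5). vstack would give (5, B).
            batch["inputs"] = torch.stack(batch["inputs"], dim=1).float()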
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            return loss

        def validation_step(self, batch, batch_idx):
            # Stack the per-feature tensors into (B, 5), one row per sample.
            batch["inputs"] = torch.stack(batch["inputs"], dim=1).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            self.auroc(torch.softmax(y_hat, -1)[:, 1], batch["labels"])
            self.log(
                "loss", loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True
            )
            self.log(
                "auroc",
                self.auroc,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

    # Load dataset
    dataset = load_dataset(
        "parquet",
        data_files=[
            "example-dataset/data0.parquet",
            "example-dataset/data1.parquet",
        ],
        split="train",
        streaming=True,
    )
    dataset = split_dataset_by_node(
        dataset, rank=env.global_rank(), world_size=env.world_size()
    )
    model = Model()

    if args.normal_dataloader:

        # Train model
        train_dl = DataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_epochs=1,
        )
        trainer.fit(model, train_dl, val_dl)

    # Solutions for training #

    # 1. Infinite dataloader - keep cycling over data and use max_steps instead
    class InfiniteDataLoader(DataLoader):
        """
        Dataloader that continually cycles over the dataset
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Initialize an iterator over the dataset.
            self.dataset_iterator = super().__iter__()
            self.epoch = 0
            self.iters = 0

        def __iter__(self):
            return self

        def __next__(self):
            try:
                batch = next(self.dataset_iterator)
            except StopIteration:
                # Dataset exhausted, use a new fresh iterator.
                self.increment_epoch()
                self.dataset_iterator = super().__iter__()
                batch = next(self.dataset_iterator)
            self.iters += 1
            return batch

        def set_epoch(self, epoch: int):
            """Pass the epoch to the dataset so it can reseed its shuffling."""

            # Works for any dataset exposing `set_epoch`, e.g. a custom `Dataset`
            # or a Hugging Face `IterableDataset`.
            if hasattr(self.dataset, "set_epoch"):
                self.dataset.set_epoch(epoch)

        def increment_epoch(self):
            self.epoch += 1
            self.iters = 0
            self.set_epoch(self.epoch)

    if args.infinite_dataloader:
        train_dl = InfiniteDataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_steps=4,
        )
        trainer.fit(model, train_dl, val_dl)

    # Solutions for eval #
    # 1. Load into memory and reshard --- would like to avoid this


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--normal-dataloader",
        action="store_true",
        help="Run normal dataloader with a sharded dataset",
    )
    parser.add_argument(
        "--infinite-dataloader",
        action="store_true",
        help="Run with an infinite dataloader with a sharded dataset using max_steps",
    )
    main(parser.parse_args())
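For evaluation, where dropping samples is not acceptable and the script's only idea so far is "load into memory and reshard", a sketch of the opposite rule: ranks that run out early keep yielding placeholder steps until every rank is exhausted, and the module skips metric updates on those steps. `pad_until_all_ranks_exhausted`, `torch_global_any` and `PaddedEvalLoader` are hypothetical names; the only library call is `torch.distributed.all_reduce` with `ReduceOp.MAX`. Because padding steps never call self.auroc(...), each real sample should be counted exactly once, and torchmetrics then synchronizes the per-rank states when the metric is computed.

```python
from typing import Callable, Iterable, Iterator, Optional, Tuple, TypeVar

T = TypeVar("T")
_EXHAUSTED = object()  # sentinel so a falsy batch is not mistaken for "no data"


def pad_until_all_ranks_exhausted(
    batches: Iterable[T], global_any: Callable[[bool], bool]
) -> Iterator[Tuple[Optional[T], bool]]:
    """Yield (batch, is_real) pairs until no rank has data left.

    global_any(flag) must return True iff flag is True on at least one rank.
    Exhausted ranks yield (None, False) so every rank runs the same number of
    steps; the caller must ignore those placeholder steps.
    """
    it = iter(batches)
    local_done = False
    while True:
        # Never call next() again on an iterator that already stopped.
        batch = _EXHAUSTED if local_done else next(it, _EXHAUSTED)
        local_done = batch is _EXHAUSTED
        if not global_any(not local_done):
            return  # every rank is out of data
        yield (None, False) if local_done else (batch, True)


def torch_global_any(flag: bool) -> bool:
    """Hypothetical global_any: MAX all-reduce of a 0/1 flag across ranks."""
    import torch
    import torch.distributed as dist

    if not (dist.is_available() and dist.is_initialized()):
        return flag  # single process: nothing to agree on
    # NCCL only reduces CUDA tensors; gloo reduces CPU tensors.
    if dist.get_backend() == "nccl":
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    t = torch.tensor([int(flag)], device=device)
    dist.all_reduce(t, op=dist.ReduceOp.MAX)
    return bool(t.item())


class PaddedEvalLoader:
    """Re-iterable wrapper: each validation run gets a fresh, padded pass."""

    def __init__(
        self, loader: Iterable, global_any: Callable[[bool], bool] = torch_global_any
    ):
        self.loader = loader
        self.global_any = global_any

    def __iter__(self):
        return pad_until_all_ranks_exhausted(self.loader, self.global_any)
```

In validation_step the batch would then arrive as a (batch, is_real) pair: unpack it and return early when is_real is False, before touching self.auroc or self.log. This assumes Lightning passes the tuple through unchanged and that every rank sees at least one real batch; a rank that never calls self.log for a key would likely skip that key's epoch-end sync and reintroduce a hang.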

Error messages and logs

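Running torchrun --nproc-per-node 2 example_ddp.py --normal-dataloader hangs because the data is uneven across ranks: rank 0 reads data0.parquet (10 rows, 2 batches) and rank 1 reads data1.parquet (20 rows, 4 batches). In the log below, rank 0 moves on to validation after its 2 batches while rank 1 is still training.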

Epoch 0: | 0/? [00:00<?, ?it/s]
1: [10, 11, 12, 13, 14]
0: [0, 1, 2, 3, 4]
Epoch 0: | 1/? [00:00<00:00, 20.45it/s, v_num=57]
0: [5, 6, 7, 8, 9]
1: [15, 16, 17, 18, 19]
Epoch 0: | 2/? [00:00<00:00, 38.07it/s, v_num=57]
1: [20, 21, 22, 23, 24]
Validation DataLoader 0: |
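The hang follows from the step counts alone. With batch_size=5, the split visible in the log (rank 0 gets ids 0-9 from the 10-row file, rank 1 gets ids from 10 onward from the 20-row file) gives the two ranks 2 and 4 steps per epoch. The datasets documentation describes split_dataset_by_node as assigning whole shards when the number of shards is a multiple of world_size (as here: 2 files on 2 ranks) and otherwise keeping one example out of every world_size on each node. The helpers below are hypothetical, pure arithmetic; round_robin_rows models only that row-level fallback.

```python
import math
from typing import List


def batches_per_rank(
    rows_per_rank: List[int], batch_size: int, drop_last: bool = False
) -> List[int]:
    """Steps each DDP rank runs in one epoch for the given per-rank row counts."""
    if drop_last:
        return [n // batch_size for n in rows_per_rank]
    return [math.ceil(n / batch_size) for n in rows_per_rank]


def round_robin_rows(total_rows: int, world_size: int) -> List[int]:
    """Rows per rank if row i goes to rank i % world_size (row-level split)."""
    return [len(range(rank, total_rows, world_size)) for rank in range(world_size)]


# Whole-file split from the log: 10 rows on rank 0, 20 rows on rank 1.
print(batches_per_rank([10, 20], batch_size=5))  # [2, 4]
# Same 30 rows split row by row across 2 ranks.
print(batches_per_rank(round_robin_rows(30, 2), batch_size=5))  # [3, 3]
```

A row-level split only narrows the gap: with 31 rows the counts become [16, 15] rows, i.e. [4, 3] batches, so some form of per-step synchronization or padding is still needed in general.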

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
    - available: True
    - version: 12.4
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.8
    - pytorch-lightning: 2.4.0
    - torch: 2.5.1
    - torchmetrics: 1.5.1
  • Packages:
    - absl-py: 2.1.0
    - accelerate: 0.34.2
    - aiohappyeyeballs: 2.4.3
    - aiohttp: 3.10.10
    - aiosignal: 1.3.1
    - antlr4-python3-runtime: 4.9.3
    - astroid: 3.3.5
    - asttokens: 2.4.1
    - async-timeout: 4.0.3
    - attrs: 24.2.0
    - autocommand: 2.2.2
    - autoflake: 2.3.1
    - autopep8: 2.3.1
    - backports.tarfile: 1.2.0
    - black: 24.10.0
    - boto3: 1.35.54
    - botocore: 1.35.54
    - c1-cube-versioning: 0.2.8
    - c1-fm-model: 0.2.0
    - certifi: 2024.8.30
    - cfgv: 3.4.0
    - charset-normalizer: 3.4.0
    - click: 8.1.7
    - comm: 0.2.2
    - contourpy: 1.3.0
    - coverage: 7.6.4
    - cramjam: 2.9.0
    - cycler: 0.12.1
    - datasets: 2.18.0
    - debugpy: 1.8.7
    - decorator: 5.1.1
    - dill: 0.3.8
    - distlib: 0.3.9
    - evaluate: 0.4.3
    - exceptiongroup: 1.2.2
    - executing: 2.1.0
    - fastparquet: 2024.5.0
    - filelock: 3.16.1
    - flake8: 7.1.1
    - fonttools: 4.54.1
    - frozenlist: 1.5.0
    - fsspec: 2024.2.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - grpcio: 1.67.1
    - huggingface-hub: 0.26.2
    - hydra-callbacks: 0.6.1
    - hydra-core: 1.3.2
    - identify: 2.6.1
    - idna: 3.10
    - importlib-metadata: 8.5.0
    - importlib-resources: 6.4.5
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - intake: 2.0.7
    - ipykernel: 6.29.5
    - ipython: 8.18.1
    - isort: 5.13.2
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - jsonpath-ng: 1.6.1
    - jupyter-client: 8.6.3
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.7
    - lightning: 2.4.0
    - lightning-utilities: 0.11.8
    - markdown: 3.7
    - markdown-it-py: 3.0.0
    - markupsafe: 3.0.2
    - matplotlib: 3.9.2
    - matplotlib-inline: 0.1.7
    - mccabe: 0.7.0
    - mdurl: 0.1.2
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multiprocess: 0.70.16
    - mypy: 1.13.0
    - mypy-extensions: 1.0.0
    - nbqa: 1.9.0
    - nest-asyncio: 1.6.0
    - networkx: 3.2.1
    - nodeenv: 1.9.1
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.4.5.8
    - nvidia-cuda-cupti-cu12: 12.4.127
    - nvidia-cuda-nvrtc-cu12: 12.4.127
    - nvidia-cuda-runtime-cu12: 12.4.127
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.2.1.3
    - nvidia-curand-cu12: 10.3.5.147
    - nvidia-cusolver-cu12: 11.6.1.9
    - nvidia-cusparse-cu12: 12.3.1.170
    - nvidia-nccl-cu12: 2.21.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.4.127
    - omegaconf: 2.3.0
    - packaging: 24.1
    - pandas: 1.5.3
    - parso: 0.8.4
    - pathspec: 0.12.1
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 11.0.0
    - pip: 24.3.1
    - platformdirs: 4.3.6
    - pluggy: 1.5.0
    - ply: 3.11
    - pre-commit: 4.0.1
    - pre-commit-hooks: 5.0.0
    - prompt-toolkit: 3.0.48
    - propcache: 0.2.0
    - protobuf: 5.28.3
    - psutil: 6.1.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - pyarrow: 14.0.1
    - pyarrow-hotfix: 0.6
    - pycodestyle: 2.12.1
    - pydantic: 1.10.18
    - pyflakes: 3.2.0
    - pygments: 2.18.0
    - pylint: 3.3.1
    - pyparsing: 3.2.0
    - pyrootutils: 1.0.4
    - pytest: 8.3.3
    - pytest-cov: 6.0.0
    - pytest-mock: 3.14.0
    - python-dateutil: 2.9.0
    - python-dotenv: 1.0.1
    - pytorch-lightning: 2.4.0
    - pytz: 2024.2
    - pyyaml: 6.0.1
    - pyzmq: 26.2.0
    - regex: 2024.9.11
    - requests: 2.32.3
    - rich: 13.9.4
    - ruamel.yaml: 0.18.6
    - ruamel.yaml.clib: 0.2.12
    - rubicon-ml: 0.10.3
    - s3fs: 0.4.2
    - s3transfer: 0.10.3
    - safetensors: 0.4.5
    - scikit-learn: 1.5.2
    - scipy: 1.13.1
    - seaborn: 0.13.2
    - setuptools: 75.3.0
    - six: 1.16.0
    - smmap: 5.0.1
    - stack-data: 0.6.2
    - sympy: 1.13.1
    - tensorboard: 2.18.0
    - tensorboard-data-server: 0.7.2
    - threadpoolctl: 3.5.0
    - tokenize-rt: 6.1.0
    - tokenizers: 0.20.3
    - tomli: 2.0.2
    - tomlkit: 0.13.2
    - torch: 2.5.1
    - torchmetrics: 1.5.1
    - tornado: 6.4.1
    - tqdm: 4.66.6
    - traitlets: 5.14.3
    - transformers: 4.45.2
    - triton: 3.1.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - urllib3: 1.26.20
    - virtualenv: 20.27.1
    - wcwidth: 0.2.13
    - werkzeug: 3.1.2
    - wheel: 0.44.0
    - xxhash: 3.5.0
    - yarl: 1.17.1
    - zipp: 3.20.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    -
    - processor: x86_64
    - python: 3.9.20
    - release: 5.10.226-214.880.amzn2.x86_64
    - version: #1 SMP Tue Oct 8 16:18:15 UTC 2024

More info

No response
