Bug description
This is partly a feature request, partly a question, and partly a report of unwanted behavior. I would like to find a valid way to train and validate with DDP on large iterable datasets when each GPU process receives a different amount of data. With the Lightning Trainer used as is, training hangs.
For training, my workaround is a dataloader that loops over the data indefinitely on each GPU process, combined with max_steps instead of max_epochs. For evaluation/validation, however, I have not found a workaround that lets torchmetrics produce valid metrics without duplicating data.
Please see the script below.
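Separately from the full reproduction script, the idea behind the training workaround can be sketched in plain Python. This is only an illustration: `cycle_batches` and the per-rank batch lists are hypothetical names that are not part of Lightning or PyTorch, and the sketch assumes each rank's data is non-empty and can be iterated more than once.

```python
from typing import Iterator, List, Sequence


def cycle_batches(batches: Sequence[List[int]], num_steps: int) -> Iterator[List[int]]:
    """Yield exactly `num_steps` batches, restarting the source when it runs out."""
    iterator = iter(batches)
    for _ in range(num_steps):
        try:
            batch = next(iterator)
        except StopIteration:
            # Source exhausted: start a fresh pass over the same data.
            iterator = iter(batches)
            batch = next(iterator)
        yield batch


# Hypothetical per-rank data: rank 0 holds 2 batches, rank 1 holds 4.
rank0 = [list(range(start, start + 5)) for start in range(0, 10, 5)]
rank1 = [list(range(start, start + 5)) for start in range(10, 30, 5)]

# With a fixed step budget, both ranks take the same number of steps.
steps0 = list(cycle_batches(rank0, num_steps=4))
steps1 = list(cycle_batches(rank1, num_steps=4))
```

Rank 0 revisits its first batch on the third step, which is acceptable for training but is exactly the duplication that makes this approach unsuitable for validation metrics.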
What version are you seeing the problem on?
v2.4
How to reproduce the bug
import os
import shutil

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics.classification.auroc import AUROC


def main(args):
    print(args)
    env = TorchElasticEnvironment()

    if env.local_rank() == 0:
        path = "example-dataset"
        if os.path.isdir(path):
            shutil.rmtree(path)
        if not os.path.exists(path):
            os.mkdir(path)
        partition_sizes = [10, 20]
        total = 0
        for i, size in enumerate(partition_sizes):
            data = pd.DataFrame(
                {
                    "id": list(range(total, total + size)),
                    "inputs": [np.random.rand(5).tolist() for _ in range(size)],
                    "labels": np.random.randint(0, 2, size).tolist(),
                }
            )
            data.to_parquet(os.path.join(path, f"data{i}.parquet"))
            total += size

    class Model(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(5, 2)
            self.auroc = AUROC(task="binary")

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop.
            print(f"{self.trainer.global_rank}: {batch['id'].cpu().numpy().tolist()} ")
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            return loss

        def validation_step(self, batch, batch_idx):
            batch["inputs"] = torch.vstack(batch["inputs"]).float()
            y_hat = self.model(batch["inputs"])
            loss = F.cross_entropy(y_hat, batch["labels"])
            self.auroc(torch.softmax(y_hat, -1)[:, 1], batch["labels"])
            self.log(
                "loss", loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True
            )
            self.log(
                "auroc",
                self.auroc,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

    # Load dataset
    dataset = load_dataset(
        "parquet",
        data_files=[
            "example-dataset/data0.parquet",
            "example-dataset/data1.parquet",
        ],
        split="train",
        streaming=True,
    )
    dataset = split_dataset_by_node(
        dataset, rank=env.global_rank(), world_size=env.world_size()
    )

    model = Model()

    if args.normal_dataloader:
        # Train model
        train_dl = DataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_epochs=1,
        )
        trainer.fit(model, train_dl, val_dl)

    # Solutions for training
    ## 1. Infinite dataloader - keep cycling over data and use max_steps instead
    class InfiniteDataLoader(DataLoader):
        """Dataloader that continually cycles over the dataset"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Initialize an iterator over the dataset.
            self.dataset_iterator = super().__iter__()
            self.epoch = 0
            self.iters = 0

        def __iter__(self):
            return self

        def __next__(self):
            try:
                batch = next(self.dataset_iterator)
            except StopIteration:
                # Dataset exhausted, use a new fresh iterator.
                self.increment_epoch()
                self.dataset_iterator = super().__iter__()
                batch = next(self.dataset_iterator)
            self.iters += 1
            return batch

        def set_epoch(self, epoch: int):
            "Set iteration for the dataset generator seed for shuffling"
            # We support if a custom `Dataset` implementation has `set_epoch`
            # or in general HF datasets `Datasets`
            if hasattr(self.dataset, "set_epoch"):
                self.dataset.set_epoch(epoch)

        def increment_epoch(self):
            self.epoch += 1
            self.iters = 0
            self.set_epoch(self.epoch)

    if args.infinite_dataloader:
        train_dl = InfiniteDataLoader(dataset, batch_size=5)
        val_dl = DataLoader(dataset, batch_size=5)
        trainer = L.Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=env.world_size(),
            num_nodes=1,
            max_steps=4,
        )
        trainer.fit(model, train_dl, val_dl)

    # Solutions for eval
    ## 1. Load into memory and reshard --- would like to avoid this


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--normal-dataloader",
        action="store_true",
        help="Run normal dataloader with a sharded dataset",
    )
    parser.add_argument(
        "--infinite-dataloader",
        action="store_true",
        help="Run with an infinite dataloader with a sharded dataset using max_steps",
    )
    main(parser.parse_args())
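For the open evaluation question, one possible direction is to have every rank keep stepping until the longest rank is done, with exhausted ranks taking part in each step but skipping the metric update. The sketch below is a single-process simulation in plain Python, not something Lightning or torchmetrics is confirmed to provide; `lockstep_eval` is a hypothetical helper, and a real DDP run would additionally need the ranks to agree on when all of them are exhausted (for example through a collective on a "has more data" flag), which is not shown or verified here.

```python
from itertools import zip_longest
from typing import List, Sequence, Tuple


def lockstep_eval(rank_batches: Sequence[Sequence[List[int]]]) -> Tuple[int, List[int]]:
    """Simulate ranks stepping together; exhausted ranks contribute nothing."""
    steps = 0
    seen_ids: List[int] = []
    # zip_longest pads the shorter ranks with None once they run out of data.
    for step in zip_longest(*rank_batches, fillvalue=None):
        steps += 1
        for batch in step:
            if batch is not None:
                # Real data: this is where a metric update would happen.
                seen_ids.extend(batch)
            # Padding step: the rank still "takes" the step but skips the update.
    return steps, seen_ids


# Hypothetical per-rank validation data matching the 10-row and 20-row files.
rank0 = [list(range(start, start + 5)) for start in range(0, 10, 5)]
rank1 = [list(range(start, start + 5)) for start in range(10, 30, 5)]

steps, seen_ids = lockstep_eval([rank0, rank1])
```

In this simulation every id is counted exactly once and both ranks take the same number of steps, which are the two properties the validation loop needs.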
Error messages and logs
Running
torchrun --nproc-per-node 2 example_ddp.py --normal-dataloader
results in the process hanging because the two ranks receive different amounts of data.
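The imbalance can be made concrete with a little arithmetic. The sketch below assumes that, with two parquet files and two ranks, each rank ends up reading one file (10 rows on one rank, 20 on the other); that assignment is an assumption about how the shards are split, and `batches_per_rank` is a hypothetical helper, not a library function.

```python
import math
from typing import List


def batches_per_rank(rows_per_rank: List[int], batch_size: int) -> List[int]:
    """Number of batches each rank's dataloader yields when the last batch is kept."""
    return [math.ceil(rows / batch_size) for rows in rows_per_rank]


# Assumed split: one file per rank, with the sizes from the script.
counts = batches_per_rank([10, 20], batch_size=5)
print(counts)  # [2, 4]
```

Under that assumption, one rank finishes its epoch after 2 steps while the other still expects 2 more synchronized steps, which is consistent with the observed hang.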
Environment
Current environment
- GPU:
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
- available: True
- version: 12.4
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- pytorch-lightning: 2.4.0
- torch: 2.5.1
- torchmetrics: 1.5.1
- absl-py: 2.1.0
- accelerate: 0.34.2
- aiohappyeyeballs: 2.4.3
- aiohttp: 3.10.10
- aiosignal: 1.3.1
- antlr4-python3-runtime: 4.9.3
- astroid: 3.3.5
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 24.2.0
- autocommand: 2.2.2
- autoflake: 2.3.1
- autopep8: 2.3.1
- backports.tarfile: 1.2.0
- black: 24.10.0
- boto3: 1.35.54
- botocore: 1.35.54
- c1-cube-versioning: 0.2.8
- c1-fm-model: 0.2.0
- certifi: 2024.8.30
- cfgv: 3.4.0
- charset-normalizer: 3.4.0
- click: 8.1.7
- comm: 0.2.2
- contourpy: 1.3.0
- coverage: 7.6.4
- cramjam: 2.9.0
- cycler: 0.12.1
- datasets: 2.18.0
- debugpy: 1.8.7
- decorator: 5.1.1
- dill: 0.3.8
- distlib: 0.3.9
- evaluate: 0.4.3
- exceptiongroup: 1.2.2
- executing: 2.1.0
- fastparquet: 2024.5.0
- filelock: 3.16.1
- flake8: 7.1.1
- fonttools: 4.54.1
- frozenlist: 1.5.0
- fsspec: 2024.2.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- grpcio: 1.67.1
- huggingface-hub: 0.26.2
- hydra-callbacks: 0.6.1
- hydra-core: 1.3.2
- identify: 2.6.1
- idna: 3.10
- importlib-metadata: 8.5.0
- importlib-resources: 6.4.5
- inflect: 7.3.1
- iniconfig: 2.0.0
- intake: 2.0.7
- ipykernel: 6.29.5
- ipython: 8.18.1
- isort: 5.13.2
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.1
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- jsonpath-ng: 1.6.1
- jupyter-client: 8.6.3
- jupyter-core: 5.7.2
- kiwisolver: 1.4.7
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- markdown: 3.7
- markdown-it-py: 3.0.0
- markupsafe: 3.0.2
- matplotlib: 3.9.2
- matplotlib-inline: 0.1.7
- mccabe: 0.7.0
- mdurl: 0.1.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.16
- mypy: 1.13.0
- mypy-extensions: 1.0.0
- nbqa: 1.9.0
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- nodeenv: 1.9.1
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.4.5.8
- nvidia-cuda-cupti-cu12: 12.4.127
- nvidia-cuda-nvrtc-cu12: 12.4.127
- nvidia-cuda-runtime-cu12: 12.4.127
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.2.1.3
- nvidia-curand-cu12: 10.3.5.147
- nvidia-cusolver-cu12: 11.6.1.9
- nvidia-cusparse-cu12: 12.3.1.170
- nvidia-nccl-cu12: 2.21.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.4.127
- omegaconf: 2.3.0
- packaging: 24.1
- pandas: 1.5.3
- parso: 0.8.4
- pathspec: 0.12.1
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 11.0.0
- pip: 24.3.1
- platformdirs: 4.3.6
- pluggy: 1.5.0
- ply: 3.11
- pre-commit: 4.0.1
- pre-commit-hooks: 5.0.0
- prompt-toolkit: 3.0.48
- propcache: 0.2.0
- protobuf: 5.28.3
- psutil: 6.1.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- pyarrow: 14.0.1
- pyarrow-hotfix: 0.6
- pycodestyle: 2.12.1
- pydantic: 1.10.18
- pyflakes: 3.2.0
- pygments: 2.18.0
- pylint: 3.3.1
- pyparsing: 3.2.0
- pyrootutils: 1.0.4
- pytest: 8.3.3
- pytest-cov: 6.0.0
- pytest-mock: 3.14.0
- python-dateutil: 2.9.0
- python-dotenv: 1.0.1
- pytorch-lightning: 2.4.0
- pytz: 2024.2
- pyyaml: 6.0.1
- pyzmq: 26.2.0
- regex: 2024.9.11
- requests: 2.32.3
- rich: 13.9.4
- ruamel.yaml: 0.18.6
- ruamel.yaml.clib: 0.2.12
- rubicon-ml: 0.10.3
- s3fs: 0.4.2
- s3transfer: 0.10.3
- safetensors: 0.4.5
- scikit-learn: 1.5.2
- scipy: 1.13.1
- seaborn: 0.13.2
- setuptools: 75.3.0
- six: 1.16.0
- smmap: 5.0.1
- stack-data: 0.6.2
- sympy: 1.13.1
- tensorboard: 2.18.0
- tensorboard-data-server: 0.7.2
- threadpoolctl: 3.5.0
- tokenize-rt: 6.1.0
- tokenizers: 0.20.3
- tomli: 2.0.2
- tomlkit: 0.13.2
- torch: 2.5.1
- torchmetrics: 1.5.1
- tornado: 6.4.1
- tqdm: 4.66.6
- traitlets: 5.14.3
- transformers: 4.45.2
- triton: 3.1.0
- typeguard: 4.3.0
- typing-extensions: 4.12.2
- urllib3: 1.26.20
- virtualenv: 20.27.1
- wcwidth: 0.2.13
- werkzeug: 3.1.2
- wheel: 0.44.0
- xxhash: 3.5.0
- yarl: 1.17.1
- zipp: 3.20.2
- OS: Linux
- architecture:
    - 64bit
- processor: x86_64
- python: 3.9.20
- release: 5.10.226-214.880.amzn2.x86_64
- version: #1 SMP Tue Oct 8 16:18:15 UTC 2024
More info
No response