This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Sub-workers exits without messages #692

Open
GongZhengLi opened this issue Mar 27, 2023 · 7 comments
Labels
bug Something isn't working

Comments

@GongZhengLi

🐛 Bug

I use the following script:

CUDA_VISIBLE_DEVICES="0, 1, 2, 3" metaseq-train --task streaming_language_modeling \
    data/pile-test/ \
    --num-workers 4 \
    --reset-dataloader \
    --vocab-filename ./vocab/gpt2-vocab.json \
    --merges-filename ./vocab/gpt2-merges.txt \
    --model-parallel-size 1 \
    --ddp-backend fully_sharded \
    --task-ddp-backend fully_sharded \
    --criterion cross_entropy \
    --batch-size 8 \
    --save-dir /checkpoints/lm_transformer_pile-00 \
    --arch transformer_lm_gpt2_tiny --share-decoder-input-output-embed \
    --dropout 0.1 \
    --optimizer adam --weight-decay 0.01 --clip-norm 0.0 \
    --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --tokens-per-sample 1024 --sample-break-mode none --fp16 \
    --use-sharded-state \
    --decoder-learned-pos \
    --log-format json \
    --log-interval 1

Ranks 1, 2, and 3 exit before the train_step loop. I printed detailed logs for every step and found that the iter() call inside more_itertools.peekable() kills all the non-master processes.
What is going on here?
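For context, a minimal sketch (not metaseq code; the toy dataset and sizes are made up) of why the crash happens at that point: peekable() calls iter() on the DataLoader immediately, and with num_workers > 0 that is exactly where the worker processes are started.

# Minimal sketch (not metaseq code; the toy dataset and sizes are made up):
# peekable() calls iter() on the DataLoader right away, and with
# num_workers > 0 that is exactly where the worker processes are started.
import more_itertools
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx)

if __name__ == "__main__":
    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=4)
    itr = more_itertools.peekable(loader)  # iter(loader) runs here and starts the workers
    print(itr.peek())                      # first batch, produced by a worker process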

GongZhengLi added the bug label on Mar 27, 2023
GongZhengLi changed the title from "Sub-wokers exits without messages" to "Sub-workers exits without messages" on Mar 27, 2023
@mahnerak
Member

mahnerak commented Apr 1, 2023

I tried to dig deeper and recover the errors causing the exit. I ended up here:
https://github.com/pytorch/pytorch/blob/db8abde9b6c4735d18d4681a1f70a55ff0b09f5b/torch/multiprocessing/spawn.py#L72-L76
Got an error:

Traceback (most recent call last):
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/USER/pip_editable_packages/metaseq/metaseq/distributed/utils.py", line 227, in distributed_main
    retval = main(cfg, **kwargs)
  File "/home/USER/pip_editable_packages/metaseq/metaseq/cli/train.py", line 181, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/USER/pip_editable_packages/metaseq/metaseq/cli/train.py", line 214, in train
    itr = epoch_itr.next_epoch_itr(
  File "/home/USER/pip_editable_packages/metaseq/metaseq/data/iterators.py", line 268, in next_epoch_itr
    self._itr = self._get_iterator_for_epoch(self.epoch)
  File "/home/USER/pip_editable_packages/metaseq/metaseq/data/iterators.py", line 395, in _get_iterator_for_epoch
    itr = StreamingCountingIterator(
  File "/home/USER/pip_editable_packages/metaseq/metaseq/data/iterators.py", line 125, in __init__
    self._peekable_itr = more_itertools.peekable(iterable)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/site-packages/more_itertools/more.py", line 311, in __init__
    self._it = iter(iterable)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 368, in __iter__
    return self._get_iterator()
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 314, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 927, in __init__
    w.start()
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/home/USER/miniconda3/envs/MY_METASEQ_ENV/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <enum 'Choices'>: attribute lookup Choices on metaseq.dataclass.constants failed

This made the workers sys.exit(1) immediately, without even letting them send the error messages.
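My reading of the error (an assumption, I have not traced metaseq's ChoiceEnum in detail): metaseq.dataclass.constants creates its choice enums dynamically with Enum("Choices", ...), so the resulting class is never bound to a module attribute literally named Choices, and the spawn start method, which pickles everything handed to the new DataLoader workers, then fails on exactly that attribute lookup. A standalone sketch of the same failure, with hypothetical names:

# Standalone sketch of the suspected failure mode (hypothetical names, not
# metaseq's actual code). The enum class below is named "Choices" but bound
# to the variable LOG_FORMAT, so pickle's attribute lookup for "Choices"
# fails -- the same way it fails for metaseq.dataclass.constants when the
# spawn start method tries to pickle objects for a DataLoader worker.
import pickle
from enum import Enum

def ChoiceEnum(choices):
    # dynamically created enum: every call produces a class named "Choices"
    return Enum("Choices", [(c, c) for c in choices])

LOG_FORMAT = ChoiceEnum(["json", "tqdm"])

try:
    pickle.dumps(LOG_FORMAT.json)
except pickle.PicklingError as e:
    print(e)  # Can't pickle <enum 'Choices'>: attribute lookup Choices on __main__ failed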

This is the command I use to launch the training:

metaseq-train \
    --task streaming_language_modeling /home/USER/PROJECT/WORKDIR/pile/ \
    --vocab-filename /home/USER/PROJECT/WORKDIR/vocab.json \
    --merges-filename /home/USER/PROJECT/WORKDIR/merges.txt \
    --criterion cross_entropy \
    --batch-size 8 \
    --save-dir /home/USER/PROJECT/WORKDIR/ckpts/a4 \
    --arch transformer_lm \
    --share-decoder-input-output-embed \
    --dropout 0.1 \
    --optimizer adam \
    --weight-decay 0.01 \
    --clip-norm 0.0 \
    --lr 0.0005 \
    --lr-scheduler inverse_sqrt \
    --warmup-updates 4000 \
    --warmup-init-lr 1e-07 \
    --tokens-per-sample 1024 \
    --sample-break-mode none \
    --decoder-learned-pos \
    --log-format json \
    --log-interval 1 \
    --aim-repo /home/USER/PROJECT/WORKDIR/. \
    --save-interval-updates 30000 \
    --fp16

Two CUDA GPUs are available.
Tested on a physical machine, a VM, and Slurm.

The single-GPU version (just setting the CUDA_VISIBLE_DEVICES env variable) works well.

@mahnerak
Member

mahnerak commented Apr 1, 2023

GongZhengLi wrote: "Ranks 1, 2, and 3 exit before the train_step loop. I printed detailed logs for every step and found that the iter() call inside more_itertools.peekable() kills all the non-master processes."

Confirming that in my case, too, the error comes from more_itertools.peekable(), so it's very likely we're hitting the same bug.

@GongZhengLi
Author

@mahnerak I solved this by adding num_workers=0.
It seems like a bug in PyTorch!
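My understanding of why that helps (an assumption, not verified against metaseq internals): with num_workers=0 the DataLoader iterates in the calling process, so no child process is spawned and nothing has to be pickled, which is why the PicklingError never fires. A toy sketch:

# Toy sketch (an assumption, not metaseq code): with num_workers=0 the
# DataLoader iterates in the calling process, so no child process is
# spawned and nothing has to be pickled.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16))
loader = DataLoader(dataset, batch_size=4, num_workers=0)
itr = iter(loader)  # single-process iterator: no spawn, no pickling
print(next(itr))    # [tensor([0, 1, 2, 3])]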

@mahnerak
Member

mahnerak commented Apr 4, 2023

Thanks @GongZhengLi

I don't think num_workers=0 will be okay in my setup. The data is too big. The training will be bottlenecked by data processing and the GPUs will be very underutilized :/

@GongZhengLi
Author

@mahnerak, did you solve it?

@mahnerak
Member

Not yet. Still waiting.
I might come back to this in a couple of days, but I'm not sure anything related to this issue has changed.

@jihwankwak

@mahnerak @GongZhengLi
Did anyone solve this issue? I am facing the same problem while using 1 node & 2 GPUs.
