`seed_everything(..., workers=True)` causes the DataLoader to apply exactly the same augmentations each epoch if they sample values from `torch.distributions` #20412
Bug description
If `seed_everything` is used with the `workers=True` flag, the seeds generated by `_generate_seed_sequence` and used in `pl_worker_init_function` make each worker apply the same augmentations/transforms every epoch, provided these augmentations use `torch.distributions` to sample random numbers (for example, `torchvision.transforms.v2.MixUp`).
In other words, when you call `torch.manual_seed()` on the seeds returned by `_generate_seed_sequence`, the number returned by `torch.distributions.Beta().sample()` is the same for every seed, even though the seeds are different. Some examples of these generated seeds and the described behaviour are shown in the Example seeds section. This is a problem because these augmentations need to change each epoch; otherwise they serve no purpose.
Example seeds
Minimal reproduction using pytest
On my fork of PyTorch Lightning, I added 3 tests that replicate the issue described above. I added this line in `fabric/utilities/seed.py` to sample values in the actual implementation of Lightning. I also copied the relevant functions into the `test_worker_custom_seed_everything` test for easier debugging; they show the same behaviour. In the `test_worker_no_seed` test, by contrast, no seed is set, and I added a custom `worker_init_fn` to the DataLoader that just samples a value from a distribution and prints it.
Outputs
python -m pytest tests/tests_fabric/test_fabric.py::test_worker_default_seed_everything -v -s
python -m pytest tests/tests_fabric/test_fabric.py::test_worker_custom_seed_everything -v -s
python -m pytest tests/tests_fabric/test_fabric.py::test_worker_no_seed -v -s
In this case, a DataLoader with only 1 worker is created to make the problem easier to observe. It also happens when multiple workers are used.
You can observe that even though the seed passed to `torch.manual_seed` in the worker's init function is different, the sampled values are the same across epochs. These tests are not complete: I could not figure out how to get either the initial seed or the sampled values of the workers inside the test function.
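The no-seed setup described for `test_worker_no_seed` can be sketched roughly as follows. This is not the code from the fork: the dataset, the batch size, and the names `sampling_worker_init_fn` and `build_loader` are assumptions made for illustration.

```python
import torch
from torch.distributions import Beta
from torch.utils.data import DataLoader, TensorDataset


def sampling_worker_init_fn(worker_id: int) -> float:
    """Sample one Beta value when a worker starts and print it with the worker's seed."""
    value = Beta(torch.tensor(1.0), torch.tensor(1.0)).sample().item()
    print(f"worker={worker_id} initial_seed={torch.initial_seed()} sampled={value:.6f}")
    return value  # ignored by DataLoader; returned only for direct inspection


def build_loader(num_workers: int = 1) -> DataLoader:
    """Build a tiny DataLoader (toy data, assumption) that uses the init function above."""
    dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
    return DataLoader(
        dataset,
        batch_size=4,
        num_workers=num_workers,
        worker_init_fn=sampling_worker_init_fn,
    )


if __name__ == "__main__":
    loader = build_loader(num_workers=1)
    for epoch in range(3):
        # With the default persistent_workers=False, workers are re-created
        # for every epoch, so the init function runs (and prints) again.
        for _ in loader:
            pass
```

Because the value is printed inside the worker process, it shows up only in the captured output (hence the `-s` flag in the pytest commands above) and is not directly available to the test function.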
Probable cause
Removing `torch.manual_seed(seed_sequence[0])` from `pl_worker_init_function` resolves the issue of repeated values across epochs. This is also how PyTorch recommends implementing this function in its documentation (only setting the random seed for `numpy` and `random`).
What version are you seeing the problem on?
master
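For reference, the change suggested under Probable cause amounts to a worker init function that seeds only NumPy and Python's `random`, in the style of the example in PyTorch's reproducibility notes. The sketch below follows that pattern under that assumption; it is not Lightning's `pl_worker_init_function`.

```python
import random

import numpy as np
import torch


def seed_worker(worker_id: int) -> None:
    """Seed NumPy and Python's `random`; leave torch's own per-worker seed untouched."""
    # Inside a DataLoader worker, torch.initial_seed() already differs per
    # worker and per epoch, so it is reused rather than overwritten.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```

It would be passed to the loader as `DataLoader(dataset, num_workers=..., worker_init_fn=seed_worker)`.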
How to reproduce the bug
Below is a full end-to-end training example of the issue. The code is a reimplementation of this paper.
How to run
- `python main.py none` - will run without setting any seed
- `python main.py custom` - will run using a version of `seed_everything` implemented in the current file, in order to see the worker seeds and sampled values
- `python main.py lightning` - will run using the default `seed_everything` function from Lightning
Plotting the train_loss graphs on wandb, you can see that the custom and lightning versions have the same smooth loss, while the version that sets no seed is spiky. This indicates that the no-seed version applies the augmentations correctly, while the seeded versions don't.
Error messages and logs
No response
Environment
Current environment
- GPU:
  - NVIDIA GeForce RTX 3060
- available: True
- version: 12.4
- lightning: 2.5.0.dev0
- lightning-utilities: 0.11.8
- pytorch-lightning: 2.4.0
- torch: 2.5.1
- torchmetrics: 1.5.2
- torchvision: 0.20.1
- absl-py: 2.1.0
- aiohappyeyeballs: 2.4.3
- aiohttp: 3.10.10
- aiosignal: 1.3.1
- alabaster: 0.7.16
- antlr4-python3-runtime: 4.9.3
- anyio: 4.6.2.post1
- appdirs: 1.4.4
- argcomplete: 3.4.0
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.3.0
- astroid: 3.2.4
- asttokens: 2.4.1
- async-lru: 2.0.4
- async-timeout: 4.0.3
- attrs: 23.2.0
- autocommand: 2.2.2
- babel: 2.16.0
- backcall: 0.2.0
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.12.3
- bitsandbytes: 0.44.1
- black: 24.8.0
- bleach: 6.2.0
- boltons: 23.0.0
- brotli: 1.0.9
- certifi: 2024.8.30
- cffi: 1.17.1
- cfgv: 3.4.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- cloudpickle: 2.2.1
- colorama: 0.4.6
- coloredlogs: 15.0.1
- comm: 0.2.2
- conda: 23.3.1
- conda-package-handling: 2.3.0
- conda-package-streaming: 0.10.0
- contourpy: 1.2.0
- coverage: 7.3.1
- cryptography: 43.0.0
- curio: 1.6
- cycler: 0.12.1
- debugpy: 1.8.8
- decorator: 5.1.1
- deepspeed: 0.9.3
- defusedxml: 0.7.1
- dill: 0.3.8
- distlib: 0.3.9
- docstring-parser: 0.16
- docutils: 0.21.2
- dotty-dict: 1.3.1
- exceptiongroup: 1.2.2
- executing: 2.1.0
- fastapi: 0.115.4
- fastjsonschema: 2.20.0
- filelock: 3.16.1
- flatbuffers: 24.3.25
- fonttools: 4.47.0
- fqdn: 1.5.1
- frozenlist: 1.5.0
- fsspec: 2024.10.0
- grpcio: 1.67.1
- h11: 0.14.0
- halo: 0.0.31
- hid: 1.0.6
- hjson: 3.1.0
- httpcore: 1.0.6
- httpx: 0.27.2
- humanfriendly: 10.0
- hydra-core: 1.3.2
- identify: 2.6.2
- idna: 3.7
- imagesize: 1.4.1
- importlib-metadata: 8.5.0
- importlib-resources: 6.1.1
- inflect: 7.3.1
- iniconfig: 2.0.0
- ipykernel: 6.29.5
- ipyparallel: 9.0.0
- ipython: 8.1.1
- ipywidgets: 8.1.5
- isoduration: 20.11.0
- isort: 5.13.2
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.2
- jinja2: 3.1.4
- joblib: 1.4.2
- json5: 0.9.28
- jsonargparse: 4.34.0
- jsonpatch: 1.33
- jsonpointer: 2.1
- jsonschema: 4.23.0
- jsonschema-specifications: 2023.12.1
- jupyter-client: 8.6.3
- jupyter-core: 5.7.2
- jupyter-events: 0.10.0
- jupyter-lsp: 2.2.5
- jupyter-server: 2.14.2
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.2.5
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.3
- jupyterlab-widgets: 3.0.13
- kiwisolver: 1.4.5
- lightning: 2.5.0.dev0
- lightning-utilities: 0.11.8
- log-symbols: 0.0.14
- markdown: 3.7
- markdown-it-py: 3.0.0
- markupsafe: 3.0.2
- matplotlib: 3.9.2
- matplotlib-inline: 0.1.7
- mccabe: 0.7.0
- mdurl: 0.1.2
- milc: 1.8.0
- mistune: 3.0.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- mypy-extensions: 1.0.0
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- ninja: 1.11.1.1
- nodeenv: 1.9.1
- notebook: 7.2.2
- notebook-shim: 0.2.4
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.4.5.8
- nvidia-cuda-cupti-cu12: 12.4.127
- nvidia-cuda-nvrtc-cu12: 12.4.127
- nvidia-cuda-runtime-cu12: 12.4.127
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.2.1.3
- nvidia-curand-cu12: 10.3.5.147
- nvidia-cusolver-cu12: 11.6.1.9
- nvidia-cusparse-cu12: 12.3.1.170
- nvidia-nccl-cu12: 2.21.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.4.127
- omegaconf: 2.3.0
- onnx: 1.17.0
- onnxruntime: 1.19.2
- outcome: 1.3.0.post0
- overrides: 7.7.0
- packaging: 24.1
- pandas: 2.2.3
- pandocfilters: 1.5.1
- parso: 0.8.4
- pathspec: 0.10.3
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 10.4.0
- pip: 24.2
- platformdirs: 3.10.0
- pluggy: 1.0.0
- pre-commit: 4.0.1
- prometheus-client: 0.21.0
- prompt-toolkit: 3.0.48
- propcache: 0.2.0
- protobuf: 5.28.3
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- py-cpuinfo: 9.0.0
- pycocotools: 2.0
- pycosat: 0.6.6
- pycparser: 2.21
- pydantic: 1.10.19
- pygments: 2.18.0
- pylint: 3.2.7
- pyopenssl: 24.2.1
- pyparsing: 3.1.1
- pyserial: 3.5
- pysocks: 1.7.1
- pytest: 7.4.0
- pytest-asyncio: 0.23.8
- pytest-cov: 4.1.0
- pytest-random-order: 1.1.0
- pytest-rerunfailures: 12.0
- pytest-timeout: 2.1.0
- python-dateutil: 2.9.0.post0
- python-json-logger: 2.0.7
- pytorch-lightning: 2.4.0
- pytz: 2024.2
- pyusb: 1.2.1
- pyyaml: 6.0.2
- pyzmq: 26.2.0
- qmk: 1.1.5
- qtconsole: 5.6.1
- qtpy: 2.4.2
- referencing: 0.35.1
- requests: 2.32.3
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rich: 13.9.4
- rpds-py: 0.19.1
- ruamel.yaml: 0.17.21
- ruamel.yaml.clib: 0.2.8
- scikit-learn: 1.5.2
- scipy: 1.13.1
- send2trash: 1.8.3
- setuptools: 75.1.0
- six: 1.16.0
- sniffio: 1.3.1
- snowballstemmer: 2.2.0
- sortedcontainers: 2.4.0
- soupsieve: 2.6
- sphinx: 7.4.7
- sphinxcontrib-applehelp: 2.0.0
- sphinxcontrib-devhelp: 2.0.0
- sphinxcontrib-htmlhelp: 2.1.0
- sphinxcontrib-jsmath: 1.0.1
- sphinxcontrib-qthelp: 2.0.0
- sphinxcontrib-serializinghtml: 2.0.0
- spinners: 0.0.24
- stack-data: 0.6.3
- starlette: 0.41.2
- sympy: 1.13.1
- tensorboard: 2.18.0
- tensorboard-data-server: 0.7.2
- tensorboardx: 2.6.2.2
- termcolor: 2.4.0
- terminado: 0.18.1
- testpath: 0.6.0
- threadpoolctl: 3.5.0
- tinycss2: 1.4.0
- tomli: 2.0.1
- tomlkit: 0.13.2
- toolz: 0.12.0
- torch: 2.5.1
- torchmetrics: 1.5.2
- torchvision: 0.20.1
- tornado: 6.4.1
- tqdm: 4.66.5
- traitlets: 5.14.3
- trio: 0.27.0
- triton: 3.1.0
- typeguard: 4.3.0
- types-colorama: 0.4.15.20240311
- types-python-dateutil: 2.9.0.20241003
- typeshed-client: 2.7.0
- typing-extensions: 4.11.0
- tzdata: 2024.2
- uri-template: 1.3.0
- urllib3: 2.2.3
- uvicorn: 0.32.0
- virtualenv: 20.27.1
- wcwidth: 0.2.13
- webcolors: 24.11.1
- webencodings: 0.5.1
- websocket-client: 1.8.0
- werkzeug: 3.1.3
- wheel: 0.44.0
- widgetsnbextension: 4.0.13
- yarl: 1.17.1
- zipp: 3.21.0
- zstandard: 0.23.0
- OS: Linux
- architecture:
  - 64bit
  - ELF
- processor: x86_64
- python: 3.9.20
- release: 5.15.153.1-microsoft-standard-WSL2
- version: #1 SMP Fri Mar 29 23:14:13 UTC 2024
More info
No response