diff --git a/README.md b/README.md index b1964314..e5f0f0ca 100644 --- a/README.md +++ b/README.md @@ -374,7 +374,9 @@ run_training( Below is a list of custom environment variables users can set in the training library. -1. `INSTRUCTLAB_NCCL_TIMEOUT_MS`, this environment variable controls the NCCL timeout in milliseconds. Consider increasing if seeing FSDP related NCCL errors. +1. `INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS`, this environment variable controls + the process group timeout in milliseconds. Consider increasing if seeing + FSDP related collective timeout errors. ## Developer Certificate of Origin diff --git a/src/instructlab/training/const.py b/src/instructlab/training/const.py new file mode 100644 index 00000000..3d316590 --- /dev/null +++ b/src/instructlab/training/const.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS = "INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS" diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index ccb417d1..29f028b2 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -52,8 +52,7 @@ TorchrunArgs, TrainingArgs, ) - -# pylint: disable=no-name-in-module +from instructlab.training.const import INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS from instructlab.training.logger import ( propagate_package_logs, setup_metric_logger, @@ -278,7 +277,7 @@ def train( # time of writing) default: to cover the unlikely event torch decides to tweak # the default. def _get_collective_timeout() -> datetime.timedelta | None: - timeout_var = os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS") + timeout_var = os.getenv(INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS) if timeout_var is None: return None @@ -289,7 +288,7 @@ def _get_collective_timeout() -> datetime.timedelta | None: if timeout <= 0: raise ValueError( - f"Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: {timeout_var}. Must be a positive integer." + f"Invalid value for {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS}: {timeout_var}. Must be a positive integer." ) return datetime.timedelta(milliseconds=timeout) diff --git a/tests/unit/test_main_ds.py b/tests/unit/test_main_ds.py index 12b35127..4df20a11 100644 --- a/tests/unit/test_main_ds.py +++ b/tests/unit/test_main_ds.py @@ -7,6 +7,7 @@ # First Party from instructlab.training import main_ds +from instructlab.training.const import INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS def test__get_collective_timeout(): @@ -16,7 +17,7 @@ def test__get_collective_timeout(): # Test with custom timeout timeout = 1234 with mock.patch.dict( - main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": str(timeout)} + main_ds.os.environ, {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS: str(timeout)} ): assert main_ds._get_collective_timeout() == datetime.timedelta( milliseconds=timeout @@ -25,7 +26,7 @@ def test__get_collective_timeout(): # Test with invalid timeout (negative) invalid_timeout = "-100" with mock.patch.dict( - main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout} + main_ds.os.environ, {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS: invalid_timeout} ): with pytest.raises(ValueError): main_ds._get_collective_timeout() @@ -33,7 +34,7 @@ def test__get_collective_timeout(): # Test with invalid timeout (string) invalid_timeout = "invalid" with mock.patch.dict( - main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout} + main_ds.os.environ, {INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS: invalid_timeout} ): with pytest.raises(ValueError): main_ds._get_collective_timeout() diff --git a/tox.ini b/tox.ini index 0794c417..f2ae4bfc 100644 --- a/tox.ini +++ b/tox.ini @@ -20,9 +20,9 @@ basepython = python3.11 [testenv:py3-unit] description = run unit tests with pytest passenv = - HF_HOME - INSTRUCTLAB_NCCL_TIMEOUT_MS - CMAKE_ARGS + HF_HOME + INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS + CMAKE_ARGS # Use PyTorch CPU build instead of CUDA build in test envs. CUDA dependencies # are huge. This reduces venv from 5.7 GB to 1.5 GB. @@ -40,8 +40,8 @@ commands = {envpython} -m pytest tests/unit {posargs} [testenv:py3-smoke] description = run accelerated smoke tests with pytest passenv = - HF_HOME - INSTRUCTLAB_NCCL_TIMEOUT_MS + HF_HOME + INSTRUCTLAB_PROCESS_GROUP_TIMEOUT_MS deps = -r requirements-dev.txt -r requirements-cuda.txt