
Commit 0cbbe2b

fix: Use default torch timeout for nccl watchdog unless overridden
The default value is recommended, and we should not change it in production. The knob may still be useful for debugging or testing purposes, though.

Signed-off-by: Ihar Hrachyshka <[email protected]>
1 parent fd03460 commit 0cbbe2b


2 files changed: +66 −6 lines changed


src/instructlab/training/main_ds.py

Lines changed: 27 additions & 6 deletions
@@ -5,6 +5,7 @@
 from pathlib import Path
 import argparse
 import datetime
+import functools
 import math
 import os
 import re
@@ -533,6 +534,24 @@ def train(
     )


+def _get_collective_timeout() -> datetime.timedelta | None:
+    timeout_var = os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS")
+    if timeout_var is None:
+        return None
+
+    try:
+        timeout = int(timeout_var)
+    except ValueError:
+        timeout = -1
+
+    if timeout <= 0:
+        raise ValueError(
+            f"Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: {timeout_var}. Must be a positive integer."
+        )
+
+    return datetime.timedelta(milliseconds=timeout)
+
+
 def main(args):
     # Third Party
     import yaml
@@ -566,15 +585,17 @@ def main(args):
     model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
     args.model_type = model_conf.model_type

-    # solution discovered from torchtune https://github.com/pytorch/torchtune/issues/2093
-    # gets converted to a timedelta of 1:40:00 if the default is kept
-    nccl_timeout = int(os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS", "6000000"))
     #### distributed init #####
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])
-    torch.distributed.init_process_group(
-        "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
-    )
+
+    timeout = _get_collective_timeout()
+    init = functools.partial(torch.distributed.init_process_group, "nccl")
+    if timeout is not None:
+        init(timeout=timeout)
+    else:
+        init()
+
     args.global_rank = torch.distributed.get_rank()
     tensor = torch.ByteTensor([False]).cuda()
     torch.distributed.all_reduce(tensor)
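
With this change, torch's own default collective timeout applies unless INSTRUCTLAB_NCCL_TIMEOUT_MS is set. As a minimal sketch (not part of the commit), the snippet below shows how an override value maps to the timedelta passed to torch.distributed.init_process_group; 6000000 ms is the value that was previously hard-coded:

    # Sketch only: an override reproducing the previous 1:40:00 watchdog timeout.
    import datetime
    import os

    os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"] = "6000000"  # previously the hard-coded default
    timeout = datetime.timedelta(milliseconds=int(os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"]))
    print(timeout)  # 1:40:00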

tests/unit/test_main_ds.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Standard
+from unittest import mock
+import datetime
+
+# Third Party
+import pytest
+
+# First Party
+from instructlab.training import main_ds
+
+
+def test__get_collective_timeout():
+    # Test with default timeout
+    assert main_ds._get_collective_timeout() is None
+
+    # Test with custom timeout
+    timeout = 1234
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": str(timeout)}
+    ):
+        assert main_ds._get_collective_timeout() == datetime.timedelta(
+            milliseconds=timeout
+        )
+
+    # Test with invalid timeout (negative)
+    invalid_timeout = "-100"
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
+    ):
+        with pytest.raises(ValueError):
+            main_ds._get_collective_timeout()
+
+    # Test with invalid timeout (string)
+    invalid_timeout = "invalid"
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
+    ):
+        with pytest.raises(ValueError):
+            main_ds._get_collective_timeout()
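
As a rough usage sketch (not part of the commit, assuming the instructlab-training package is importable), the helper can also be exercised directly through the environment variable it reads, mirroring the mock.patch.dict setup in the test above:

    # Sketch only: drive _get_collective_timeout() via its environment variable.
    import os
    from instructlab.training import main_ds

    os.environ.pop("INSTRUCTLAB_NCCL_TIMEOUT_MS", None)
    print(main_ds._get_collective_timeout())  # None -> torch's default timeout is kept

    os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"] = "600000"
    print(main_ds._get_collective_timeout())  # 0:10:00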

0 commit comments
