
Commit 5136461

Merge pull request #521 from booxter/revert-to-default-timeout-for-pytorch
fix: Use default torch timeout for nccl watchdog unless overridden
2 parents 6c76c98 + 2cdb409 commit 5136461

2 files changed: +70 -6 lines changed

src/instructlab/training/main_ds.py

Lines changed: 31 additions & 6 deletions
@@ -5,6 +5,7 @@
 from pathlib import Path
 import argparse
 import datetime
+import functools
 import logging
 import math
 import os
@@ -544,6 +545,28 @@ def train(
     )


+# This function makes an effort to stick to a default value from torch library,
+# whatever it may be. That's why we don't just set to the current (as of the
+# time of writing) default: to cover the unlikely event torch decides to tweak
+# the default.
+def _get_collective_timeout() -> datetime.timedelta | None:
+    timeout_var = os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS")
+    if timeout_var is None:
+        return None
+
+    try:
+        timeout = int(timeout_var)
+    except ValueError:
+        timeout = -1
+
+    if timeout <= 0:
+        raise ValueError(
+            f"Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: {timeout_var}. Must be a positive integer."
+        )
+
+    return datetime.timedelta(milliseconds=timeout)
+
+
 def main(args):
     if args.distributed_training_framework == "deepspeed" and not FusedAdam:
         raise ImportError(
@@ -571,15 +594,17 @@ def main(args):
     model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
     args.model_type = model_conf.model_type

-    # solution discovered from torchtune https://github.com/pytorch/torchtune/issues/2093
-    # gets converted to a timedelta of 1:40:00 if the default is kept
-    nccl_timeout = int(os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS", "6000000"))
     #### distributed init #####
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])
-    torch.distributed.init_process_group(
-        "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
-    )
+
+    timeout = _get_collective_timeout()
+    init = functools.partial(torch.distributed.init_process_group, "nccl")
+    if timeout is not None:
+        init(timeout=timeout)
+    else:
+        init()
+
     args.global_rank = torch.distributed.get_rank()
     tensor = torch.ByteTensor([False]).cuda()
     torch.distributed.all_reduce(tensor)
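
For context on how the override behaves at runtime, here is a minimal sketch (not part of the commit; the environment variable name comes from the diff above, and the numeric value is illustrative only):

import datetime
import os

from instructlab.training import main_ds

# With the variable unset, the helper returns None and torch's own default
# collective timeout stays in effect.
os.environ.pop("INSTRUCTLAB_NCCL_TIMEOUT_MS", None)
assert main_ds._get_collective_timeout() is None

# A positive value in milliseconds becomes the timeout passed to
# torch.distributed.init_process_group.
os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"] = "600000"  # illustrative value
assert main_ds._get_collective_timeout() == datetime.timedelta(milliseconds=600000)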

tests/unit/test_main_ds.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Standard
+from unittest import mock
+import datetime
+
+# Third Party
+import pytest
+
+# First Party
+from instructlab.training import main_ds
+
+
+def test__get_collective_timeout():
+    # Test with default timeout
+    assert main_ds._get_collective_timeout() is None
+
+    # Test with custom timeout
+    timeout = 1234
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": str(timeout)}
+    ):
+        assert main_ds._get_collective_timeout() == datetime.timedelta(
+            milliseconds=timeout
+        )
+
+    # Test with invalid timeout (negative)
+    invalid_timeout = "-100"
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
+    ):
+        with pytest.raises(ValueError):
+            main_ds._get_collective_timeout()
+
+    # Test with invalid timeout (string)
+    invalid_timeout = "invalid"
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
+    ):
+        with pytest.raises(ValueError):
+            main_ds._get_collective_timeout()
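
A note on the test design: mock.patch.dict restores os.environ when each with block exits, so the overrides do not leak between cases. The first assertion does assume INSTRUCTLAB_NCCL_TIMEOUT_MS is not already set in the surrounding environment; a hypothetical stricter variant (not part of the commit) could clear the environment for that check:

from unittest import mock

from instructlab.training import main_ds

# Hypothetical variant, not in the commit: clear os.environ for the duration of
# the block so an ambient INSTRUCTLAB_NCCL_TIMEOUT_MS cannot affect the default case.
def test__get_collective_timeout_default_env_cleared():
    with mock.patch.dict(main_ds.os.environ, clear=True):
        assert main_ds._get_collective_timeout() is None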
