5 | 5 | from pathlib import Path
6 | 6 | import argparse
7 | 7 | import datetime
| 8 | +import functools
8 | 9 | import logging
9 | 10 | import math
10 | 11 | import os

@@ -544,6 +545,28 @@ def train(
544 | 545 |     )
545 | 546 |
546 | 547 |
| 548 | +# This function makes an effort to stick to a default value from the torch library,
| 549 | +# whatever it may be. That's why we don't just set it to the current (as of the
| 550 | +# time of writing) default: to cover the unlikely event that torch decides to tweak
| 551 | +# the default.
| 552 | +def _get_collective_timeout() -> datetime.timedelta | None:
| 553 | +    timeout_var = os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS")
| 554 | +    if timeout_var is None:
| 555 | +        return None
| 556 | +
| 557 | +    try:
| 558 | +        timeout = int(timeout_var)
| 559 | +    except ValueError:
| 560 | +        timeout = -1
| 561 | +
| 562 | +    if timeout <= 0:
| 563 | +        raise ValueError(
| 564 | +            f"Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: {timeout_var}. Must be a positive integer."
| 565 | +        )
| 566 | +
| 567 | +    return datetime.timedelta(milliseconds=timeout)
| 568 | +
| 569 | +
547 | 570 | def main(args):
548 | 571 |     if args.distributed_training_framework == "deepspeed" and not FusedAdam:
549 | 572 |         raise ImportError(
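
A quick usage sketch of the helper added above (illustrative only, not part of the diff; it assumes _get_collective_timeout is in scope in this training module):

    import datetime
    import os

    # Unset: return None so the caller falls back to torch's own default timeout.
    os.environ.pop("INSTRUCTLAB_NCCL_TIMEOUT_MS", None)
    assert _get_collective_timeout() is None

    # A positive integer is read as milliseconds (here, 10 minutes).
    os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"] = "600000"
    assert _get_collective_timeout() == datetime.timedelta(milliseconds=600000)

    # Non-numeric or non-positive values raise ValueError.
    os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"] = "0"
    try:
        _get_collective_timeout()
    except ValueError:
        pass  # "Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: 0. Must be a positive integer."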
@@ -571,15 +594,17 @@ def main(args):
571 | 594 |     model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
572 | 595 |     args.model_type = model_conf.model_type
573 | 596 |
574 | | -    # solution discovered from torchtune https://github.com/pytorch/torchtune/issues/2093
575 | | -    # gets converted to a timedelta of 1:40:00 if the default is kept
576 | | -    nccl_timeout = int(os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS", "6000000"))
577 | 597 |     #### distributed init #####
578 | 598 |     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
579 | 599 |     args.local_rank = int(os.environ["LOCAL_RANK"])
580 | | -    torch.distributed.init_process_group(
581 | | -        "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
582 | | -    )
| 600 | +
| 601 | +    timeout = _get_collective_timeout()
| 602 | +    init = functools.partial(torch.distributed.init_process_group, "nccl")
| 603 | +    if timeout is not None:
| 604 | +        init(timeout=timeout)
| 605 | +    else:
| 606 | +        init()
| 607 | +
583 | 608 |     args.global_rank = torch.distributed.get_rank()
584 | 609 |     tensor = torch.ByteTensor([False]).cuda()
585 | 610 |     torch.distributed.all_reduce(tensor)
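
In main(), functools.partial is just a convenience so the positional "nccl" backend argument is not repeated in both branches; torch keeps its own default collective timeout whenever INSTRUCTLAB_NCCL_TIMEOUT_MS is unset. An equivalent kwargs-based sketch (illustrative only; the helper name _init_distributed is hypothetical, not part of the change):

    import datetime
    import torch

    def _init_distributed(timeout: datetime.timedelta | None = None) -> None:
        # Pass `timeout` through only when it was explicitly configured, so that
        # torch.distributed.init_process_group keeps its own default otherwise.
        kwargs = {} if timeout is None else {"timeout": timeout}
        torch.distributed.init_process_group("nccl", **kwargs)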