
set device_id in torch's init_process_group #7266


Draft
wants to merge 7 commits into base: master

Conversation

@stas00 (Collaborator) commented Apr 30, 2025

This PR addresses the following issue that arises when making any torch.distributed call with DeepSpeed:

[W404 00:15:21.693690333 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 
to perform barrier as devices used by this process are currently unknown. This can
 potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in
 barrier() to force use of a particular device, or call init_process_group() with a device_id.

It does so by setting device_id to the device corresponding to the LOCAL_RANK env var.
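
For context, a minimal sketch of the idea (an illustration, not the exact patch): read the LOCAL_RANK env var set by the launcher and build the torch.device that gets passed as device_id.

```python
# Hedged illustration only: derive the device from LOCAL_RANK (set by the
# launcher) and hand it to init_process_group via the device_id argument.
import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device_id = torch.device("cuda", local_rank)
# torch.distributed.init_process_group(..., device_id=device_id)
```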


Update: discovered that torch.distributed deadlocks with torch>=2.7.0 when the device_id arg is used. Switching to draft for now, as we can't commit this until we know how to work around it.

@stas00 stas00 requested a review from GuanhuaWang as a code owner April 30, 2025 19:09
@stas00 stas00 requested review from loadams and removed request for GuanhuaWang April 30, 2025 19:09
@stas00 (Collaborator, Author) commented May 6, 2025

@loadams?

@loadams (Collaborator) commented May 7, 2025

> @loadams?

Sorry @stas00, I missed this and will review today.

torch.distributed.init_process_group(backend,
                                     timeout=timeout,
                                     init_method=init_method,
                                     rank=rank,
-                                    world_size=world_size)
+                                    world_size=world_size,
+                                    device_id=torch.device('cuda', local_rank))
@loadams (Collaborator):

@stas00 - the cuda here will cause failures on non-cuda backends like HPU (not sure why the tests didn't run, but ran manually here: https://github.com/deepspeedai/DeepSpeed/actions/runs/14886572284/job/41807642413)

@stas00 (Collaborator, Author), May 7, 2025:

aha, thank you so much for seeing the big picture, @loadams

so we need something like:

device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')

or should I add device_id only if torch.cuda.is_available() and do nothing otherwise? I don't know what device to use in the HPU case if it's not cpu.
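
Roughly, the second option would look like the following sketch (just an illustration, assuming backend, timeout, init_method, rank, world_size and local_rank are already in scope):

```python
# Sketch of the "only pass device_id when CUDA is available" option;
# all other kwargs stay exactly as they are today.
kwargs = dict(timeout=timeout, init_method=init_method, rank=rank, world_size=world_size)
if torch.cuda.is_available():
    kwargs["device_id"] = torch.device("cuda", local_rank)
torch.distributed.init_process_group(backend, **kwargs)
```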

Collaborator:

Could we use get_accelerator() here?

@stas00 (Collaborator, Author):

Whatever works. Could you please show what you have in mind specifically for filling in:

device_id=torch.device('cuda', local_rank)?

Collaborator:

I think what is needed is get_accelerator().device(local_rank).

For cuda this maps to torch.cuda.device(device_index)

@stas00, does that work?
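
For illustration, a sketch of how that wiring could look, going through the accelerator abstraction to build the torch.device that init_process_group's device_id expects (assuming get_accelerator().device_name(local_rank) returns a device string such as 'cuda:3' or 'hpu:3'; this is not the committed patch):

```python
# Hedged sketch: build a backend-agnostic torch.device for device_id via the
# DeepSpeed accelerator abstraction instead of hard-coding 'cuda'.
import os
import torch
from deepspeed.accelerator import get_accelerator

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device_id = torch.device(get_accelerator().device_name(local_rank))
# torch.distributed.init_process_group(backend, ..., device_id=device_id)
```

The sketch wraps the accelerator's device name in torch.device rather than passing the object returned by get_accelerator().device(...), since on CUDA the latter maps to the torch.cuda.device context manager as noted above.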

@stas00 (Collaborator, Author):

Nope, still randomly deadlocks:

Thread 474480 (idle): "MainThread"
    broadcast (torch/distributed/distributed_c10d.py:2772)
    wrapper (torch/distributed/c10d_logger.py:81)
    broadcast (deepspeed/comm/torch.py:216)
    broadcast (deepspeed/comm/comm.py:224)
    log_wrapper (deepspeed/comm/comm.py:117)
    _zero_init_param (deepspeed/runtime/zero/partition_parameters.py:1054)
    _post_init_method (deepspeed/runtime/zero/partition_parameters.py:1099)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:521)
    __init__ (transformers/models/llama/modeling_llama.py:166)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:511)
    __init__ (transformers/models/llama/modeling_llama.py:297)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:511)
    <listcomp> (transformers/models/llama/modeling_llama.py:477)
    __init__ (transformers/models/llama/modeling_llama.py:477)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:511)
    __init__ (transformers/models/llama/modeling_llama.py:740)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:511)
    from_pretrained (transformers/modeling_utils.py:4340)
    _wrapper (transformers/modeling_utils.py:279)
    from_pretrained (transformers/models/auto/auto_factory.py:571)
    from_pretrained (liger_kernel/transformers/auto_model.py:38)
    create_model (arctic_training/model/liger_factory.py:45)
    wrapper (arctic_training/callback/mixin.py:45)
    __call__ (arctic_training/model/factory.py:68)
    __init__ (arctic_training/trainer/trainer.py:228)
    wrapper (arctic_training/callback/mixin.py:45)
    run_script (arctic_training/cli.py:108)
    <module> (arctic_training_run:8)
Thread 476034 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Process 474481: /usr/bin/python -u /home/yak/.local/bin/arctic_training_run --local_rank=6 --mode train --config run-dp1-sp8.yml
Python v3.10.12 (/usr/bin/python3.10)

Thread 474481 (active): "MainThread"
    wrapped_fn (deepspeed/runtime/zero/partition_parameters.py:240)
    _compute_default_rope_parameters (transformers/modeling_rope_utils.py:130)
    __init__ (transformers/models/llama/modeling_llama.py:106)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:511)
    __init__ (transformers/models/llama/modeling_llama.py:480)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:511)
    __init__ (transformers/models/llama/modeling_llama.py:740)
    wrapper (deepspeed/runtime/zero/partition_parameters.py:511)
    from_pretrained (transformers/modeling_utils.py:4340)
    _wrapper (transformers/modeling_utils.py:279)
    from_pretrained (transformers/models/auto/auto_factory.py:571)
    from_pretrained (liger_kernel/transformers/auto_model.py:38)
    create_model (arctic_training/model/liger_factory.py:45)
    wrapper (arctic_training/callback/mixin.py:45)
    __call__ (arctic_training/model/factory.py:68)
    __init__ (arctic_training/trainer/trainer.py:228)
    wrapper (arctic_training/callback/mixin.py:45)
    run_script (arctic_training/cli.py:108)
    <module> (arctic_training_run:8)
Thread 476031 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Process 474482: /usr/bin/python -u /home/yak/.local/bin/arctic_training_run --local_rank=7 --mode train --config run-dp1-sp8.yml
Python v3.10.12 (/usr/bin/python3.10)

@stas00 (Collaborator, Author):

so definitely let's not merge this - asking pytorch folks for help.

Collaborator:

Good point. Yes, let's align with HF Transformers. Thanks.

@loadams can you help with this?

@loadams: @sfc-gh-truwase @stas00 - yes, I think we should do something like this. Min 2.1 would be good. Agreed 2.3 might be a bit rushed, but let me check what CUDA/GPU versions that implies as well.

@stas00 (Collaborator, Author):

Further investigation shows the deadlocks start with torch>=2.7.0. It's difficult to debug since they aren't always reproducible, but they usually happen within 3-6 re-runs.

@stas00 (Collaborator, Author):

So I'm switching to draft for now, as we can't commit this until we know how to work around it. I'm actively pursuing this with the pytorch devs.

A seemingly related issue is: modded-nanogpt flaky NCCL hang starting 3/30 nightly

stas00 added 2 commits May 15, 2025 11:10
Signed-off-by: Stas Bekman <[email protected]>
@stas00 stas00 marked this pull request as draft May 15, 2025 22:15