feat: Add no_ssh and slurm multinode launcher options for deepspeed #3329

Open · wants to merge 2 commits into `main`
9 changes: 8 additions & 1 deletion docs/source/usage_guides/deepspeed.md
@@ -167,7 +167,7 @@ Currently, `Accelerate` supports the following config through the CLI:
`deepspeed_hostfile`: DeepSpeed hostfile for configuring multi-node compute resources.
`deepspeed_exclusion_filter`: DeepSpeed exclusion filter string when using a multi-node setup.
`deepspeed_inclusion_filter`: DeepSpeed inclusion filter string when using a multi-node setup.
-`deepspeed_multinode_launcher`: DeepSpeed multi-node launcher to use. If unspecified, will default to `pdsh`.
+`deepspeed_multinode_launcher`: DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.
`deepspeed_config_file`: path to the DeepSpeed config file in `json` format. See the next section for more details on this.
```
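For example, these flags can be passed directly to `accelerate launch` (a minimal sketch; the script name, hostfile path, and process counts are placeholder assumptions):

```bash
# Illustrative 2-node DeepSpeed launch selecting the slurm launcher
accelerate launch --use_deepspeed \
  --num_machines 2 \
  --num_processes 16 \
  --deepspeed_hostfile ./hostfile \
  --deepspeed_multinode_launcher slurm \
  train.py
```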
To be able to tweak more options, you will need to use a DeepSpeed config file.
@@ -710,6 +710,13 @@ model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
2. Current integration doesn’t support `mpu`, limiting the tensor parallelism which is supported in Megatron-LM.
3. Current integration doesn’t support multiple models.

## Multi-node DeepSpeed
DeepSpeed supports multi-node inference and training over a variety of launchers. You can select a launcher by setting `deepspeed_multinode_launcher`, either on the CLI or in the `deepspeed_config` section of your accelerate config file.

Currently, Accelerate supports passing configuration for the following DeepSpeed multi-node launchers: `pdsh` (default), `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, and `nossh` (requires DeepSpeed >= 0.14.5).

Please read the [DeepSpeed documentation](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) for more information on the different launchers. By default, DeepSpeed uses passwordless SSH from the main node to the other nodes to run the launcher command, so `accelerate launch` only needs to be run on the main node. With the `nossh` launcher, you must instead run `accelerate launch` on every node, each with an identical copy of the configuration, as sketched below.
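A sketch of what that looks like in practice with `nossh` (the config file and script names are hypothetical):

```bash
# Run the same command on every node, varying only the machine rank
accelerate launch --config_file ds_multinode.yaml --machine_rank 0 train.py  # on node 0
accelerate launch --config_file ds_multinode.yaml --machine_rank 1 train.py  # on node 1
```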

## DeepSpeed Resources

The documentation for the internals related to deepspeed can be found [here](../package_reference/deepspeed).
2 changes: 1 addition & 1 deletion src/accelerate/commands/launch.py
@@ -498,7 +498,7 @@ def launch_command_parser(subparsers=None):
         "--deepspeed_multinode_launcher",
         default=None,
         type=str,
-        help="DeepSpeed multi-node launcher to use. If unspecified, will default to `pdsh`.",
+        help="DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.",
     )
     deepspeed_args.add_argument(
         "--deepspeed_moe_layer_cls_names",
2 changes: 1 addition & 1 deletion src/accelerate/utils/constants.py
@@ -41,7 +41,7 @@
     "2.1.0.a0+32f93b1"  # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
 )
 FSDP_MODEL_NAME = "pytorch_model_fsdp"
-DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
+DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
 TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
 ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
 XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
11 changes: 9 additions & 2 deletions src/accelerate/utils/launch.py
@@ -36,6 +36,7 @@
 )
 from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
 from ..utils.other import is_port_in_use, merge_dicts
+from ..utils.versions import compare_versions
 from .dataclasses import DistributedType, SageMakerDistributedType


@@ -321,8 +322,14 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
         args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]

     if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
-        cmd = ["deepspeed", "--no_local_rank"]
-        cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)])
+        cmd = ["deepspeed"]
+        cmd.extend(["--hostfile", str(args.deepspeed_hostfile)])
+        if args.deepspeed_multinode_launcher == "nossh":
+            if compare_versions("deepspeed", "<", "0.14.5"):
+                raise ValueError("nossh launcher requires DeepSpeed >= 0.14.5")
+            cmd.extend(["--node_rank", str(args.machine_rank), "--no_ssh"])
+        else:
+            cmd.extend(["--no_local_rank", "--launcher", str(args.deepspeed_multinode_launcher)])
         if args.deepspeed_exclusion_filter is not None:
             cmd.extend(
                 [
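For reference, a hedged approximation of the final commands this branch assembles (the hostfile path and training script are placeholders):

```bash
# nossh branch: no --launcher/--no_local_rank; the node rank is passed explicitly,
# and the command is run on every node
deepspeed --hostfile ./hostfile --node_rank 0 --no_ssh train.py

# SSH-based launchers (e.g. pdsh): run once on the main node
deepspeed --hostfile ./hostfile --no_local_rank --launcher pdsh train.py
```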