Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpirun protocol - distributed training with @remote decorator #4998

Merged
merged 21 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
841af92
implemented multi-node distribution with @remote function
brunopistone Jan 4, 2025
4fe2747
completed unit tests
brunopistone Jan 6, 2025
fa79639
added distributed training with CPU and torchrun
brunopistone Jan 8, 2025
43547b0
Merge branch 'master' into master
brunopistone Jan 13, 2025
06ab509
backwards compatibility nproc_per_node
brunopistone Jan 14, 2025
3a03c4b
Merge branch 'master' of https://github.com/brunopistone/sagemaker-py…
brunopistone Jan 14, 2025
bc5918a
Merge branch 'master' into master
brunopistone Jan 14, 2025
7d54096
fixing code: permissions for non-root users, integration tests
brunopistone Jan 15, 2025
423c585
fixed docstyle
brunopistone Jan 15, 2025
adcc38e
refactor nproc_per_node for backwards compatibility
brunopistone Jan 15, 2025
00eb637
refactor nproc_per_node for backwards compatibility
brunopistone Jan 15, 2025
0dea502
pylint fix, newlines
brunopistone Jan 16, 2025
b152915
added unit tests for bootstrap_environment remote
brunopistone Jan 16, 2025
c11f130
added mpirun protocol for distributed training with @remote decorator
brunopistone Jan 21, 2025
73cc79d
Merge branch 'master' into master
brunopistone Jan 21, 2025
8a54cc2
Merge branch 'master' into master
benieric Jan 29, 2025
8701782
Merge branch 'master' into master
benieric Jan 29, 2025
86c9f7d
aligned mpi_utils_remote.py to mpi_utils.py for estimator
brunopistone Jan 29, 2025
4a69cfd
Merge branch 'master' into master
benieric Jan 30, 2025
fd8a70c
updated docstring for sagemaker sdk doc
brunopistone Jan 30, 2025
9506d70
Merge branch 'master' of https://github.com/brunopistone/sagemaker-py…
brunopistone Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/sagemaker/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def remote(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
use_torchrun: bool = False,
use_mpirun: bool = False,
nproc_per_node: Optional[int] = None,
):
"""Decorator for running the annotated function as a SageMaker training job.
Expand Down Expand Up @@ -284,6 +285,9 @@ def remote(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

use_mpirun (bool): Specifies whether to use mpirun for distributed training.
Defaults to ``False``.

nproc_per_node (Optional int): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
Expand Down Expand Up @@ -320,19 +324,21 @@ def _remote(func):
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
use_torchrun=use_torchrun,
use_mpirun=use_mpirun,
nproc_per_node=nproc_per_node,
)

@functools.wraps(func)
def wrapper(*args, **kwargs):

if instance_count > 1 and not (
(spark_config is not None and not use_torchrun)
or (spark_config is None and use_torchrun)
(spark_config is not None and not use_torchrun and not use_mpirun)
or (spark_config is None and use_torchrun and not use_mpirun)
or (spark_config is None and not use_torchrun and use_mpirun)
):
raise ValueError(
"Remote function do not support training on multi instances "
+ "without spark_config or use_torchrun. "
+ "without spark_config or use_torchrun or use_mpirun. "
+ "Please provide instance_count = 1"
)

Expand Down Expand Up @@ -536,7 +542,8 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
use_torchrun: bool = False,
use_mpirun: bool = False,
nproc_per_node: Optional[int] = None,
):
"""Constructor for RemoteExecutor
Expand Down Expand Up @@ -730,6 +737,9 @@ def __init__(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

use_mpirun (bool): Specifies whether to use mpirun for distributed training.
Defaults to ``False``.

nproc_per_node (Optional int): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
Expand All @@ -740,12 +750,13 @@ def __init__(
raise ValueError("max_parallel_jobs must be greater than 0.")

if instance_count > 1 and not (
(spark_config is not None and not use_torchrun)
or (spark_config is None and use_torchrun)
(spark_config is not None and not use_torchrun and not use_mpirun)
or (spark_config is None and use_torchrun and not use_mpirun)
or (spark_config is None and not use_torchrun and use_mpirun)
):
raise ValueError(
"Remote function do not support training on multi instances "
+ "without spark_config or use_torchrun. "
+ "without spark_config or use_torchrun or use_mpirun. "
+ "Please provide instance_count = 1"
)

Expand Down Expand Up @@ -778,6 +789,7 @@ def __init__(
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
use_torchrun=use_torchrun,
use_mpirun=use_mpirun,
nproc_per_node=nproc_per_node,
)

Expand Down
153 changes: 147 additions & 6 deletions src/sagemaker/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@

# runtime script names
# Filenames of the runtime helper scripts that are staged into the bootstrap
# channel and uploaded to S3 for the training job (see
# _prepare_and_upload_runtime_scripts, which copies the first and last of
# these from the runtime_environment package directory).
BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"  # sets up the job runtime env
MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py"  # host coordination helper for mpirun jobs
ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"  # name of the generated bash entry point
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"  # presumably user pre-execution commands — not used in this view, confirm
RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py"  # copied alongside the bootstrap script
Expand Down Expand Up @@ -167,6 +168,99 @@
fi
"""

# Bash entry-point template for multi-node mpirun jobs. This is an f-string
# rendered at import time: {RUNTIME_SCRIPTS_CHANNEL_NAME}, {BOOTSTRAP_SCRIPT_NAME},
# {MPI_UTILS_SCRIPT_NAME} and {JOB_REMOTE_FUNCTION_WORKSPACE} are substituted
# here in Python (hence the doubled ${{...}} for shell parameter expansion),
# while $SM_* variables are resolved by the shell from sm_training.env.
#
# Flow visible in the script: bootstrap the runtime environment, source
# /opt/ml/input/sm_training.env, cd into the workspace if present, then:
#   - on the master host ($SM_CURRENT_HOST == $SM_MASTER_ADDR): run
#     {MPI_UTILS_SCRIPT_NAME}, launch the remote function under mpirun
#     (via "$conda_exe run -n $conda_env" when remote_function_conda_env.txt
#     exists, preferring mamba over conda), and finally run
#     {MPI_UTILS_SCRIPT_NAME} --job_ended 1;
#   - on worker hosts: only run {MPI_UTILS_SCRIPT_NAME} (presumably it blocks
#     until the master finishes — confirm in mpi_utils_remote.py).
#
# NOTE(review): in the conda branch the informational printf omits the
# "$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA" line and
# contains a stray blank continuation line, so the logged command differs
# slightly from the command actually executed just below it.
ENTRYPOINT_MPIRUN_SCRIPT = f"""
#!/bin/bash

# Entry point for bootstrapping runtime environment and invoking remote function with mpirun

set -eu

PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
cat /opt/ml/input/config/resourceconfig.json

printf "INFO: Bootstraping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
source /opt/ml/input/sm_training.env

if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
if [ -f "remote_function_conda_env.txt" ]
then
cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
fi
printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
cd {JOB_REMOTE_FUNCTION_WORKSPACE}
fi

if [ -f "remote_function_conda_env.txt" ]
then
conda_env=$(cat remote_function_conda_env.txt)

if which mamba >/dev/null; then
conda_exe="mamba"
else
conda_exe="conda"
fi

if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}

printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n"
printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \

python -m mpi4py -m sagemaker.remote_function.invoke_function \\n"
$conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"

python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
else
printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
fi
else
if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}

printf "INFO: No conda env provided. Invoking remote function with mpirun\\n"
printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
python -m mpi4py -m sagemaker.remote_function.invoke_function \\n"

mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"

python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
else
printf "INFO: This is the instance $SM_CURRENT_HOST.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
fi
fi
"""

ENTRYPOINT_TORCHRUN_SCRIPT = f"""
#!/bin/bash

Expand Down Expand Up @@ -211,13 +305,15 @@
printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.remote_function.invoke_function \\n"

$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"

torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
fi
Expand Down Expand Up @@ -278,6 +374,7 @@ def __init__(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun: bool = False,
use_mpirun: bool = False,
nproc_per_node: Optional[int] = None,
):
"""Initialize a _JobSettings instance which configures the remote job.
Expand Down Expand Up @@ -464,6 +561,9 @@ def __init__(
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

use_mpirun (bool): Specifies whether to use mpirun for distributed training.
Defaults to ``False``.

nproc_per_node (Optional int): Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
Expand Down Expand Up @@ -626,6 +726,7 @@ def __init__(
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

self.use_torchrun = use_torchrun
self.use_mpirun = use_mpirun
self.nproc_per_node = nproc_per_node

@staticmethod
Expand Down Expand Up @@ -874,6 +975,12 @@ def compile(
).to_string(),
]
)
if job_settings.use_torchrun:
container_args.extend(["--distribution", "torchrun"])
elif job_settings.use_mpirun:
container_args.extend(["--distribution", "mpirun"])
if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0:
container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)])
if job_settings.s3_kms_key:
container_args.extend(["--s3_kms_key", job_settings.s3_kms_key])

Expand Down Expand Up @@ -950,6 +1057,7 @@ def compile(
request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
extended_request = _extend_mpirun_to_request(extended_request, job_settings)
extended_request = _extend_torchrun_to_request(extended_request, job_settings)

return extended_request
Expand Down Expand Up @@ -1031,7 +1139,7 @@ def _prepare_and_upload_runtime_scripts(
s3_kms_key: str,
sagemaker_session: Session,
use_torchrun: bool = False,
nproc_per_node: Optional[int] = None,
use_mpirun: bool = False,
):
"""Copy runtime scripts to a folder and upload to S3.

Expand All @@ -1050,6 +1158,8 @@ def _prepare_and_upload_runtime_scripts(

use_torchrun (bool): Whether to use torchrun or not.

use_mpirun (bool): Whether to use mpirun or not.

nproc_per_node (Optional[int]): Number of processes per node
"""

Expand All @@ -1075,23 +1185,25 @@ def _prepare_and_upload_runtime_scripts(
if use_torchrun:
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT

if nproc_per_node is not None and nproc_per_node > 0:
entry_point_script = entry_point_script.replace(
"$SM_NPROC_PER_NODE", str(nproc_per_node)
)
if use_mpirun:
entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT

with open(entrypoint_script_path, "w", newline="\n") as file:
file.writelines(entry_point_script)

bootstrap_script_path = os.path.join(
os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME
)
mpi_utils_path = os.path.join(
os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME
)
runtime_manager_script_path = os.path.join(
os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME
)

# copy runtime scripts to tmpdir
shutil.copy2(bootstrap_script_path, bootstrap_scripts)
shutil.copy2(mpi_utils_path, bootstrap_scripts)
shutil.copy2(runtime_manager_script_path, bootstrap_scripts)

upload_path = S3Uploader.upload(
Expand All @@ -1118,7 +1230,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=job_settings.sagemaker_session,
use_torchrun=job_settings.use_torchrun,
nproc_per_node=job_settings.nproc_per_node,
use_mpirun=job_settings.use_mpirun,
)

input_data_config = [
Expand Down Expand Up @@ -1459,6 +1571,35 @@ def _upload_serialized_spark_configuration(
return config_file_s3_uri


def _extend_mpirun_to_request(
    request_dict: Dict,
    job_settings: _JobSettings,
) -> Dict:
    """Extend the create training job request with mpirun configuration.

    For a multi-instance mpirun job, every input channel backed by an
    ``S3DataSource`` has its ``S3DataDistributionType`` set to
    ``FullyReplicated`` so each node receives the full channel contents.

    Args:
        request_dict (Dict): create training job request dict.
        job_settings (_JobSettings): the job settings.

    Returns:
        Dict: a new request dict with the mpirun configuration applied, or
        ``request_dict`` unchanged when mpirun is disabled or the job runs
        on a single instance.
    """
    # No-op unless mpirun is enabled on a multi-instance job.
    if not job_settings.use_mpirun or job_settings.instance_count == 1:
        return request_dict

    import copy  # function-scope: only needed on this branch

    # Deep-copy so the caller's request_dict is never mutated; a shallow
    # .copy() would still alias the nested channel dicts and mutate them
    # in place.
    extended_request = copy.deepcopy(request_dict)

    for input_channel in extended_request["InputDataConfig"]:
        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
        if s3_data_source:
            s3_data_source["S3DataDistributionType"] = "FullyReplicated"

    return extended_request


def _extend_torchrun_to_request(
request_dict: Dict,
job_settings: _JobSettings,
Expand Down
Loading
Loading