diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 5bdbe50bc8..0868e09379 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -89,3 +89,4 @@ class EventType(object): AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" + GET_JOB_LAUNCHER = "_get_job_launcher" diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index a4908c6b6e..70b2e2af41 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -160,6 +160,8 @@ class FLContextKey(object): AUTHORIZATION_REASON = "_authorization_reason" DISCONNECTED_CLIENT_NAME = "_disconnected_client_name" RECONNECTED_CLIENT_NAME = "_reconnected_client_name" + SITE_OBJ = "_site_obj_" + JOB_LAUNCHER = "_job_launcher" CLIENT_REGISTER_DATA = "_client_register_data" SECURITY_ITEMS = "_security_items" @@ -324,7 +326,7 @@ class SnapshotKey(object): class RunProcessKey(object): LISTEN_PORT = "_listen_port" CONNECTION = "_conn" - CHILD_PROCESS = "_child_process" + JOB_HANDLE = "_job_launcher" STATUS = "_status" JOB_ID = "_job_id" PARTICIPANTS = "_participants" @@ -356,6 +358,10 @@ class JobConstants: CLIENT_JOB_CONFIG = "config_fed_client.json" META_FILE = "meta.json" META = "meta" + SITES = "sites" + JOB_IMAGE = "image" + JOB_ID = "job_id" + JOB_LAUNCHER = "job_launcher" class WorkspaceConstants: diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py new file mode 100644 index 0000000000..fcc0e04c94 --- /dev/null +++ b/nvflare/apis/job_launcher_spec.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import abstractmethod
+
+from nvflare.apis.fl_component import FLComponent
+from nvflare.apis.fl_constant import FLContextKey
+from nvflare.apis.fl_context import FLContext
+from nvflare.fuel.common.exit_codes import ProcessExitCode
+
+
+class JobReturnCode(ProcessExitCode):
+    SUCCESS = 0
+    EXECUTION_ERROR = 1
+    ABORTED = 9
+    UNKNOWN = 127
+
+
+def add_launcher(launcher, fl_ctx: FLContext):
+    job_launcher: list = fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER, [])
+    job_launcher.append(launcher)
+    fl_ctx.set_prop(FLContextKey.JOB_LAUNCHER, job_launcher, private=True, sticky=False)
+
+
+class JobHandleSpec:
+    @abstractmethod
+    def terminate(self):
+        """To terminate the job run.
+
+        Returns: the job run return code.
+
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def poll(self):
+        """To get the return code of the job run.
+
+        Returns: return_code
+
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def wait(self):
+        """To wait until the job run completes.
+
+        Returns: None. The call blocks until the job run completes.
+
+        """
+        raise NotImplementedError()
+
+
+class JobLauncherSpec(FLComponent):
+    @abstractmethod
+    def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
+        """To launch a job run.
+
+        Args:
+            job_meta: job meta data
+            fl_ctx: FLContext
+
+        Returns: a JobHandleSpec object for the launched job run.
+ + """ + raise NotImplementedError() diff --git a/nvflare/apis/server_engine_spec.py b/nvflare/apis/server_engine_spec.py index be1c40f24d..4c4013f0c1 100644 --- a/nvflare/apis/server_engine_spec.py +++ b/nvflare/apis/server_engine_spec.py @@ -203,12 +203,12 @@ def restore_components(self, snapshot: RunSnapshot, fl_ctx: FLContext): pass @abstractmethod - def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): + def start_client_job(self, job, client_sites, fl_ctx: FLContext): """To send the start client run commands to the clients Args: client_sites: client sites - job_id: job_id + job: job object fl_ctx: FLContext Returns: diff --git a/nvflare/app_common/job_launcher/__init__.py b/nvflare/app_common/job_launcher/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/app_common/job_launcher/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/job_launcher/process_launcher.py b/nvflare/app_common/job_launcher/process_launcher.py new file mode 100644 index 0000000000..912893feff --- /dev/null +++ b/nvflare/app_common/job_launcher/process_launcher.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import shlex
+import subprocess
+import sys
+
+from nvflare.apis.event_type import EventType
+from nvflare.apis.fl_constant import FLContextKey
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.job_def import JobMetaKey
+from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobReturnCode, add_launcher
+from nvflare.apis.workspace import Workspace
+from nvflare.private.fed.utils.fed_utils import add_custom_dir_to_path, extract_job_image
+
+JOB_RETURN_CODE_MAPPING = {0: JobReturnCode.SUCCESS, 1: JobReturnCode.EXECUTION_ERROR, 9: JobReturnCode.ABORTED}
+
+
+class ProcessHandle(JobHandleSpec):
+    def __init__(self, process):
+        super().__init__()
+
+        self.process = process
+        self.logger = logging.getLogger(self.__class__.__name__)
+
+    def terminate(self):
+        if self.process:
+            try:
+                os.killpg(os.getpgid(self.process.pid), 9)
+                self.logger.debug("kill signal sent")
+            except Exception:  # best effort only: the group may already be gone; terminate() below still runs
+                self.logger.debug("could not kill the process group; falling back to terminate()")
+
+            self.process.terminate()
+
+    def poll(self):
+        exit_code = self.process.poll() if self.process else None
+        if exit_code is None:  # no process, or it is still running: there is no return code yet
+            return JobReturnCode.UNKNOWN
+        return JOB_RETURN_CODE_MAPPING.get(exit_code, JobReturnCode.EXECUTION_ERROR)
+
+    def wait(self):
+        if self.process:
+            self.process.wait()
+
+
+class ProcessJobLauncher(JobLauncherSpec):
+    def __init__(self):
+        super().__init__()
+
+        self.logger = logging.getLogger(self.__class__.__name__)
+
+    def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
+
+        new_env = os.environ.copy()
+        workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
+        args =
fl_ctx.get_prop(FLContextKey.ARGS) + client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) + job_id = job_meta.get(JobMetaKey.JOB_ID) + server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) + if not server_config: + raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") + service = server_config[0].get("service", {}) + if not isinstance(service, dict): + raise RuntimeError(f"expect server config data to be dict but got {type(service)}") + + app_custom_folder = workspace_obj.get_app_custom_dir(job_id) + if app_custom_folder != "": + add_custom_dir_to_path(app_custom_folder, new_env) + + command_options = "" + for t in args.set: + command_options += " " + t + command = ( + f"{sys.executable} -m nvflare.private.fed.app.client.worker_process -m " + + args.workspace + + " -w " + + (workspace_obj.get_startup_kit_dir()) + + " -t " + + client.token + + " -d " + + client.ssid + + " -n " + + job_id + + " -c " + + client.client_name + + " -p " + + str(client.cell.get_internal_listener_url()) + + " -g " + + service.get("target") + + " -scheme " + + service.get("scheme", "grpc") + + " -s fed_client.json " + " --set" + command_options + " print_conf=True" + ) + # use os.setsid to create new process group ID + process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env) + + self.logger.info("Worker child process ID: {}".format(process.pid)) + + return ProcessHandle(process) + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.GET_JOB_LAUNCHER: + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) + if not job_image: + add_launcher(self, fl_ctx) diff --git a/nvflare/app_common/job_schedulers/job_scheduler.py b/nvflare/app_common/job_schedulers/job_scheduler.py index c7e03d394f..d5cd35c817 100644 --- a/nvflare/app_common/job_schedulers/job_scheduler.py +++ b/nvflare/app_common/job_schedulers/job_scheduler.py @@ -25,6 +25,7 @@ 
from nvflare.apis.job_def_manager_spec import JobDefManagerSpec from nvflare.apis.job_scheduler_spec import DispatchInfo, JobSchedulerSpec from nvflare.apis.server_engine_spec import ServerEngineSpec +from nvflare.private.fed.utils.fed_utils import extract_participants SCHEDULE_RESULT_OK = 0 # the job is scheduled SCHEDULE_RESULT_NO_RESOURCE = 1 # job is not scheduled due to lack of resources @@ -109,7 +110,9 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp applicable_sites = [] sites_to_app = {} for app_name in job.deploy_map: - for site_name in job.deploy_map[app_name]: + deployments = job.deploy_map[app_name] + deployments = extract_participants(deployments) + for site_name in deployments: if site_name.upper() == ALL_SITES: # deploy_map: {"app_name": ["ALL_SITES"]} will be treated as deploying to all online clients applicable_sites = online_site_names diff --git a/nvflare/app_opt/job_launcher/__init__.py b/nvflare/app_opt/job_launcher/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/app_opt/job_launcher/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py new file mode 100644 index 0000000000..bba0df01e0 --- /dev/null +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -0,0 +1,275 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import time +from enum import Enum + +from kubernetes import config +from kubernetes.client import Configuration +from kubernetes.client.api import core_v1_api +from kubernetes.client.rest import ApiException + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import FLContextKey, JobConstants +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobReturnCode, add_launcher +from nvflare.apis.workspace import Workspace +from nvflare.private.fed.utils.fed_utils import extract_job_image + + +class JobState(Enum): + STARTING = "starting" + RUNNING = "running" + TERMINATED = "terminated" + SUCCEEDED = "succeeded" + UNKNOWN = "unknown" + + +POD_STATE_MAPPING = { + "Pending": JobState.STARTING, + "Running": JobState.RUNNING, + "Succeeded": JobState.SUCCEEDED, + "Failed": JobState.TERMINATED, + "Unknown": JobState.UNKNOWN, +} + +JOB_RETURN_CODE_MAPPING = { + JobState.SUCCEEDED: JobReturnCode.SUCCESS, + JobState.STARTING: JobReturnCode.UNKNOWN, + JobState.RUNNING: JobReturnCode.UNKNOWN, + JobState.TERMINATED: 
JobReturnCode.ABORTED, + JobState.UNKNOWN: JobReturnCode.UNKNOWN, +} + + +class K8sJobHandle(JobHandleSpec): + def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, namespace="default", timeout=None): + super().__init__() + self.job_id = job_id + self.timeout = timeout + + self.api_instance = api_instance + self.namespace = namespace + self.pod_manifest = { + "apiVersion": "v1", + "kind": "Pod", + "metadata": {"name": None}, # set by job_config['name'] + "spec": { + "containers": None, # link to container_list + "volumes": None, # link to volume_list + "restartPolicy": "OnFailure", + }, + } + self.volume_list = [{"name": None, "hostPath": {"path": None, "type": "Directory"}}] + self.container_list = [ + { + "image": None, + "name": None, + "command": ["/usr/local/bin/python"], + "args": None, # args_list + args_dict + args_sets + "volumeMounts": None, # volume_mount_list + "imagePullPolicy": "Always", + } + ] + self.container_args_python_args_list = ["-u", "-m", "nvflare.private.fed.app.client.worker_process"] + self.container_args_module_args_dict = { + "-m": None, + "-w": None, + "-t": None, + "-d": None, + "-n": None, + "-c": None, + "-p": None, + "-g": None, + "-scheme": None, + "-s": None, + } + self.container_volume_mount_list = [ + { + "name": None, + "mountPath": None, + } + ] + self._make_manifest(job_config) + + def _make_manifest(self, job_config): + self.container_volume_mount_list = job_config.get( + "volume_mount_list", [{"name": "workspace-nvflare", "mountPath": "/workspace/nvflare"}] + ) + set_list = job_config.get("set_list") + if set_list is None: + self.container_args_module_args_sets = list() + else: + self.container_args_module_args_sets = ["--set"] + set_list + self.container_args_module_args_dict = job_config.get( + "module_args", + { + "-m": None, + "-w": None, + "-t": None, + "-d": None, + "-n": None, + "-c": None, + "-p": None, + "-g": None, + "-scheme": None, + "-s": None, + }, + ) + 
self.container_args_module_args_dict_as_list = list()
+        for k, v in self.container_args_module_args_dict.items():
+            self.container_args_module_args_dict_as_list.append(k)
+            self.container_args_module_args_dict_as_list.append(v)
+        self.volume_list = job_config.get(
+            "volume_list", [{"name": None, "hostPath": {"path": None, "type": "Directory"}}]
+        )
+
+        self.pod_manifest["metadata"]["name"] = job_config.get("name")
+        self.pod_manifest["spec"]["containers"] = self.container_list
+        self.pod_manifest["spec"]["volumes"] = self.volume_list
+
+        self.container_list[0]["image"] = job_config.get("image", "nvflare/nvflare:2.5.0")
+        self.container_list[0]["name"] = job_config.get("container_name", "nvflare-job")  # DNS label: no "_"
+        self.container_list[0]["args"] = (
+            self.container_args_python_args_list
+            + self.container_args_module_args_dict_as_list
+            + self.container_args_module_args_sets
+        )
+        self.container_list[0]["volumeMounts"] = self.container_volume_mount_list
+
+    def get_manifest(self):
+        return self.pod_manifest
+
+    def enter_states(self, job_states_to_enter: list, timeout=None):
+        starting_time = time.time()
+        if not isinstance(job_states_to_enter, (list, tuple)):
+            job_states_to_enter = [job_states_to_enter]
+        if not all(isinstance(js, JobState) for js in job_states_to_enter):  # check each state, not a 1-item list
+            raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}")
+        while True:
+            job_state = self._query_state()
+            if job_state in job_states_to_enter:
+                return True
+            elif timeout is not None and time.time() - starting_time > timeout:
+                return False
+            time.sleep(1)
+
+    def terminate(self):
+        self.api_instance.delete_namespaced_pod(
+            name=self.job_id, namespace=self.namespace, grace_period_seconds=0
+        )
+        return self.enter_states([JobState.TERMINATED], timeout=self.timeout)
+
+    def poll(self):
+        job_state = self._query_state()
+        return JOB_RETURN_CODE_MAPPING.get(job_state, JobReturnCode.UNKNOWN)
+
+    def _query_state(self):
+        try:
+            resp =
self.api_instance.read_namespaced_pod(name=self.job_id, namespace=self.namespace) + except ApiException as e: + return JobState.UNKNOWN + return POD_STATE_MAPPING.get(resp.status.phase, JobState.UNKNOWN) + + def wait(self): + self.enter_states([JobState.SUCCEEDED, JobState.TERMINATED]) + + +class K8sJobLauncher(JobLauncherSpec): + def __init__( + self, + config_file_path, + root_hostpath: str, + workspace: str, + mount_path: str, + timeout=None, + namespace="default", + ): + super().__init__() + + self.root_hostpath = root_hostpath + self.workspace = workspace + self.mount_path = mount_path + self.timeout = timeout + + config.load_kube_config(config_file_path) + try: + c = Configuration().get_default_copy() + except AttributeError: + c = Configuration() + c.assert_hostname = False + Configuration.set_default(c) + self.core_v1 = core_v1_api.CoreV1Api() + self.namespace = namespace + + self.job_handle = None + self.logger = logging.getLogger(self.__class__.__name__) + + def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: + + workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + args = fl_ctx.get_prop(FLContextKey.ARGS) + client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) + job_id = job_meta.get(JobConstants.JOB_ID) + server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) + if not server_config: + raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") + service = server_config[0].get("service", {}) + if not isinstance(service, dict): + raise RuntimeError(f"expect server config data to be dict but got {type(service)}") + + self.logger.info(f"K8sJobLauncher start to launch job: {job_id} for client: {client.client_name}") + job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) + self.logger.info(f"launch job use image: {job_image}") + job_config = { + "name": job_id, + "image": job_image, + "container_name": f"container-{job_id}", + "volume_mount_list": [{"name": self.workspace, "mountPath": 
self.mount_path}], + "volume_list": [{"name": self.workspace, "hostPath": {"path": self.root_hostpath, "type": "Directory"}}], + "module_args": { + "-m": args.workspace, + "-w": (workspace_obj.get_startup_kit_dir()), + "-t": client.token, + "-d": client.ssid, + "-n": job_id, + "-c": client.client_name, + "-p": "tcp://parent-pod:8004", + "-g": service.get("target"), + "-scheme": service.get("scheme", "grpc"), + "-s": "fed_client.json", + }, + "set_list": args.set, + } + + self.logger.info(f"launch job with k8s_launcher. Job_id:{job_id}") + + job_handle = K8sJobHandle(job_id, self.core_v1, job_config, namespace=self.namespace, timeout=self.timeout) + try: + self.core_v1.create_namespaced_pod(body=job_handle.get_manifest(), namespace=self.namespace) + if job_handle.enter_states([JobState.RUNNING], timeout=self.timeout): + return job_handle + else: + job_handle.terminate() + return None + except ApiException as e: + job_handle.terminate() + return None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.GET_JOB_LAUNCHER: + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) + if job_image: + add_launcher(self, fl_ctx) diff --git a/nvflare/lighter/impl/master_template.yml b/nvflare/lighter/impl/master_template.yml index 1477faad4e..1265d1206e 100644 --- a/nvflare/lighter/impl/master_template.yml +++ b/nvflare/lighter/impl/master_template.yml @@ -92,6 +92,11 @@ local_client_resources: | "id": "resource_consumer", "path": "nvflare.app_common.resource_consumers.gpu_resource_consumer.GPUResourceConsumer", "args": {} + }, + { + "id": "process_launcher", + "path": "nvflare.app_common.job_launcher.process_launcher.ProcessJobLauncher", + "args": {} } ] } diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index 20320fb328..7618622ad8 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ 
b/nvflare/private/fed/app/client/client_train.py @@ -100,6 +100,7 @@ def main(args): federated_client.use_gpu = False federated_client.config_folder = config_folder + workspace = Workspace(args.workspace, federated_client.client_name, config_folder) client_engine = ClientEngine(federated_client, args, rank) @@ -108,6 +109,8 @@ def main(args): time.sleep(1.0) with client_engine.new_context() as fl_ctx: + client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + fl_ctx.set_prop( key=FLContextKey.CLIENT_CONFIG, value=deployer.client_config, @@ -128,7 +131,8 @@ def main(args): ) fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) - client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + fl_ctx.set_prop(FLContextKey.ARGS, args, private=True, sticky=True) + fl_ctx.set_prop(FLContextKey.SITE_OBJ, federated_client, private=True, sticky=True) component_security_check(fl_ctx) diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index 1d19fac603..f88d9e5772 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -450,7 +450,7 @@ def simulator_run_main(self): try: self.create_clients() self.server.engine.run_processes[SimulatorConstants.JOB_NAME] = { - RunProcessKey.CHILD_PROCESS: None, + RunProcessKey.JOB_HANDLE: None, RunProcessKey.JOB_ID: SimulatorConstants.JOB_NAME, RunProcessKey.PARTICIPANTS: self.server.engine.client_manager.clients, } diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index 6c4acb92e9..6f293abdb4 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -32,7 +32,7 @@ from nvflare.security.logging import secure_format_exception, secure_log_traceback from .client_engine_internal_spec import ClientEngineInternalSpec -from .client_executor import ProcessExecutor +from 
.client_executor import JobExecutor from .client_run_manager import ClientRunInfo from .client_status import ClientStatus from .fed_client import FederatedClient @@ -62,7 +62,7 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): self.client_name = client.client_name self.args = args self.rank = rank - self.client_executor = ProcessExecutor(client, os.path.join(args.workspace, "startup")) + self.client_executor = JobExecutor(client, os.path.join(args.workspace, "startup")) self.admin_agent = None self.fl_ctx_mgr = FLContextManager( @@ -134,6 +134,7 @@ def get_engine_status(self): def start_app( self, job_id: str, + job_meta: dict, allocated_resource: dict = None, token: str = None, resource_manager=None, @@ -160,17 +161,16 @@ def start_app( self.logger.info("Starting client app. rank: {}".format(self.rank)) - server_config = list(self.client.servers.values())[0] self.client_executor.start_app( self.client, job_id, + job_meta, self.args, app_custom_folder, allocated_resource, token, resource_manager, - target=server_config["target"], - scheme=server_config.get("scheme", "grpc"), + fl_ctx=self.new_context(), ) return "Start the client app..." diff --git a/nvflare/private/fed/client/client_engine_internal_spec.py b/nvflare/private/fed/client/client_engine_internal_spec.py index bd1e852cdc..95008fb7b1 100644 --- a/nvflare/private/fed/client/client_engine_internal_spec.py +++ b/nvflare/private/fed/client/client_engine_internal_spec.py @@ -53,6 +53,7 @@ def deploy_app(self, app_name: str, job_id: str, job_meta: dict, client_name: st def start_app( self, job_id: str, + job_meta: dict, allocated_resource: dict = None, token: str = None, resource_manager=None, diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index e2e8f51586..6a8a111239 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -13,22 +13,21 @@ # limitations under the License. 
import logging -import os -import shlex -import subprocess -import sys import threading import time from abc import ABC, abstractmethod -from nvflare.apis.fl_constant import AdminCommandNames, RunProcessKey, SystemConfigs +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemConfigs +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_launcher_spec import JobLauncherSpec from nvflare.apis.resource_manager_spec import ResourceManagerSpec from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode from nvflare.fuel.utils.config_service import ConfigService from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message -from nvflare.private.fed.utils.fed_utils import add_custom_dir_to_path, get_return_code +from nvflare.private.fed.utils.fed_utils import get_return_code from nvflare.security.logging import secure_format_exception, secure_log_traceback from .client_status import ClientStatus, get_status_message @@ -40,13 +39,13 @@ def start_app( self, client, job_id, + job_meta, args, app_custom_folder, allocated_resource, token, resource_manager, - target: str, - scheme: str, + fl_ctx: FLContext, ): """Starts the client app. 
@@ -58,8 +57,7 @@ def start_app( allocated_resource: allocated resources token: token from resource manager resource_manager: resource manager - target: SP target location - scheme: SP target connection scheme + fl_ctx: FLContext """ pass @@ -122,7 +120,7 @@ def reset_errors(self, job_id): """ -class ProcessExecutor(ClientExecutor): +class JobExecutor(ClientExecutor): """Run the Client executor in a child process.""" def __init__(self, client, startup): @@ -145,66 +143,37 @@ def start_app( self, client, job_id, + job_meta, args, app_custom_folder, allocated_resource, token, resource_manager: ResourceManagerSpec, - target: str, - scheme: str, + fl_ctx: FLContext, ): """Starts the app. Args: client: the FL client object job_id: the job_id + job_meta: job meta data args: admin command arguments for starting the worker process app_custom_folder: FL application custom folder allocated_resource: allocated resources token: token from resource manager resource_manager: resource manager - target: SP target location - scheme: SP connection scheme + fl_ctx: FLContext """ - new_env = os.environ.copy() - if app_custom_folder != "": - add_custom_dir_to_path(app_custom_folder, new_env) - - command_options = "" - for t in args.set: - command_options += " " + t - command = ( - f"{sys.executable} -m nvflare.private.fed.app.client.worker_process -m " - + args.workspace - + " -w " - + self.startup - + " -t " - + client.token - + " -d " - + client.ssid - + " -n " - + job_id - + " -c " - + client.client_name - + " -p " - + str(client.cell.get_internal_listener_url()) - + " -g " - + target - + " -scheme " - + scheme - + " -s fed_client.json " - " --set" + command_options + " print_conf=True" - ) - # use os.setsid to create new process group ID - process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env) - self.logger.info("Worker child process ID: {}".format(process.pid)) + job_launcher: JobLauncherSpec = self._get_job_launcher(job_meta, fl_ctx) + 
job_handle = job_launcher.launch_job(job_meta, fl_ctx) + self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") client.multi_gpu = False with self.lock: self.run_processes[job_id] = { - RunProcessKey.CHILD_PROCESS: process, + RunProcessKey.JOB_HANDLE: job_handle, RunProcessKey.STATUS: ClientStatus.STARTING, } @@ -214,6 +183,17 @@ def start_app( ) thread.start() + def _get_job_launcher(self, job_meta: dict, fl_ctx: FLContext) -> JobLauncherSpec: + engine = fl_ctx.get_engine() + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + engine.fire_event(EventType.GET_JOB_LAUNCHER, fl_ctx) + + job_launcher = fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER) + if not (job_launcher and isinstance(job_launcher, list)): + raise RuntimeError(f"There's no job launcher can handle this job: {job_meta}.") + + return job_launcher[0] + def notify_job_status(self, job_id, job_status): run_process = self.run_processes.get(job_id) if run_process: @@ -336,7 +316,7 @@ def abort_app(self, job_id): if process_status == ClientStatus.STARTED: try: with self.lock: - child_process = self.run_processes[job_id][RunProcessKey.CHILD_PROCESS] + job_handle = self.run_processes[job_id][RunProcessKey.JOB_HANDLE] data = {} request = new_cell_message({}, data) self.client.cell.fire_and_forget( @@ -347,7 +327,7 @@ def abort_app(self, job_id): optional=True, ) self.logger.debug("abort sent to worker") - t = threading.Thread(target=self._terminate_process, args=[child_process, job_id]) + t = threading.Thread(target=self._terminate_job, args=[job_handle, job_id]) t.start() t.join() break @@ -365,7 +345,7 @@ def abort_app(self, job_id): self.logger.info("Client worker process is terminated.") - def _terminate_process(self, child_process, job_id): + def _terminate_job(self, job_handle, job_id): max_wait = 10.0 done = False start = time.time() @@ -382,16 +362,7 @@ def _terminate_process(self, child_process, job_id): time.sleep(0.05) # we want to quickly check - 
# kill the sub-process group directly - if not done: - self.logger.debug(f"still not done after {max_wait} secs") - try: - os.killpg(os.getpgid(child_process.pid), 9) - self.logger.debug("kill signal sent") - except: - pass - - child_process.terminate() + job_handle.terminate() self.logger.info(f"run ({job_id}): child worker process terminated") def abort_task(self, job_id): @@ -415,11 +386,11 @@ def abort_task(self, job_id): def _wait_child_process_finish(self, client, job_id, allocated_resource, token, resource_manager, workspace): self.logger.info(f"run ({job_id}): waiting for child worker process to finish.") - child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.CHILD_PROCESS) - if child_process: - child_process.wait() + job_handle = self.run_processes.get(job_id, {}).get(RunProcessKey.JOB_HANDLE) + if job_handle: + job_handle.wait() - return_code = get_return_code(child_process, job_id, workspace, self.logger) + return_code = get_return_code(job_handle, job_id, workspace, self.logger) self.logger.info(f"run ({job_id}): child worker process finished with RC {return_code}") if return_code in [ProcessExitCode.UNSAFE_COMPONENT, ProcessExitCode.CONFIG_ERROR]: diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index 7ea2e8d55b..ef47354365 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -101,6 +101,7 @@ def process(self, req: Message, app_ctx) -> Message: try: resource_spec = req.body job_id = req.get_header(RequestHeader.JOB_ID) + job_meta = req.get_header(RequestHeader.JOB_META) token = req.get_header(ShareableHeader.RESOURCE_RESERVE_TOKEN) except Exception as e: msg = f"{ERROR_MSG_PREFIX}: Start job execution exception, missing required information: {secure_format_exception(e)}." 
@@ -116,6 +117,7 @@ def process(self, req: Message, app_ctx) -> Message: resource_consumer.consume(allocated_resources) result = engine.start_app( job_id, + job_meta=job_meta, allocated_resource=allocated_resources, token=token, resource_manager=resource_manager, diff --git a/nvflare/private/fed/server/job_meta_validator.py b/nvflare/private/fed/server/job_meta_validator.py index 0856b736eb..9362125081 100644 --- a/nvflare/private/fed/server/job_meta_validator.py +++ b/nvflare/private/fed/server/job_meta_validator.py @@ -24,6 +24,7 @@ from nvflare.apis.job_meta_validator_spec import JobMetaValidatorSpec from nvflare.fuel.utils.config import ConfigFormat from nvflare.fuel.utils.config_factory import ConfigFactory +from nvflare.private.fed.utils.fed_utils import extract_participants from nvflare.security.logging import secure_format_exception CONFIG_FOLDER = "/config/" @@ -101,7 +102,11 @@ def _validate_deploy_map(job_name: str, meta: dict) -> list: if not deploy_map: raise ValueError(f"deploy_map is empty for job {job_name}") - site_list = [site for deployments in deploy_map.values() for site in deployments] + site_list = [] + for deployments in deploy_map.values(): + deployments = extract_participants(deployments) + for site in deployments: + site_list.append(site) if not site_list: raise ValueError(f"No site is specified in deploy_map for job {job_name}") @@ -126,6 +131,7 @@ def _validate_app(self, job_name: str, meta: dict, zip_file: ZipFile) -> None: has_byoc = False for app, deployments in deploy_map.items(): + deployments = extract_participants(deployments) config_folder = job_name + "/" + app + CONFIG_FOLDER if not self._entry_exists(zip_file, config_folder): diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index d6bf91e888..7c38294027 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -35,7 +35,7 @@ from nvflare.private.fed.server.admin import 
check_client_replies from nvflare.private.fed.server.server_state import HotState from nvflare.private.fed.utils.app_deployer import AppDeployer -from nvflare.private.fed.utils.fed_utils import set_message_security_data +from nvflare.private.fed.utils.fed_utils import extract_participants, set_message_security_data from nvflare.security.logging import secure_format_exception @@ -131,6 +131,7 @@ def _deploy_job(self, job: Job, sites: dict, fl_ctx: FLContext) -> Tuple[str, li for app_name, participants in job.get_deployment().items(): app_data = job.get_application(app_name, fl_ctx) + participants = extract_participants(participants) if len(participants) == 1 and participants[0].upper() == ALL_SITES: participants = ["server"] @@ -249,7 +250,7 @@ def _start_run(self, job_id: str, job: Job, client_sites: Dict[str, DispatchInfo if err: raise RuntimeError(f"Could not start the server App for job: {job_id}.") - replies = engine.start_client_job(job_id, client_sites, fl_ctx) + replies = engine.start_client_job(job, client_sites, fl_ctx) client_sites_names = list(client_sites.keys()) check_client_replies(replies=replies, client_sites=client_sites_names, command=f"start job ({job_id})") display_sites = ",".join(client_sites_names) diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index 1bf8f415ee..be9823414e 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -285,7 +285,7 @@ def _start_runner_process( with self.lock: self.run_processes[run_number] = { - RunProcessKey.CHILD_PROCESS: process, + RunProcessKey.JOB_HANDLE: process, RunProcessKey.JOB_ID: job_id, RunProcessKey.PARTICIPANTS: job_clients, } @@ -333,7 +333,7 @@ def abort_app_on_server(self, job_id: str, turn_to_cold: bool = False) -> str: self.logger.info(f"Abort server status: {status_message}") except Exception: with self.lock: - child_process = self.run_processes.get(job_id, 
{}).get(RunProcessKey.CHILD_PROCESS, None) + child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.JOB_HANDLE, None) if child_process: child_process.terminate() finally: @@ -873,13 +873,14 @@ def cancel_client_resources( if requests: _ = self._send_admin_requests(requests, fl_ctx) - def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): + def start_client_job(self, job, client_sites, fl_ctx: FLContext): requests = {} for site, dispatch_info in client_sites.items(): resource_requirement = dispatch_info.resource_requirements token = dispatch_info.token request = Message(topic=TrainingTopic.START_JOB, body=resource_requirement) - request.set_header(RequestHeader.JOB_ID, job_id) + request.set_header(RequestHeader.JOB_ID, job.job_id) + request.set_header(RequestHeader.JOB_META, job.meta) request.set_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, token) client = self.get_client_from_name(site) if client: diff --git a/nvflare/private/fed/simulator/simulator_client_engine.py b/nvflare/private/fed/simulator/simulator_client_engine.py index fd0a3969c1..47039a1147 100644 --- a/nvflare/private/fed/simulator/simulator_client_engine.py +++ b/nvflare/private/fed/simulator/simulator_client_engine.py @@ -25,7 +25,7 @@ def __init__(self, client, args, rank=0): fl_ctx.set_prop(FLContextKey.SIMULATE_MODE, True, private=True, sticky=True) self.client_executor.run_processes[SimulatorConstants.JOB_NAME] = { - RunProcessKey.CHILD_PROCESS: None, + RunProcessKey.JOB_HANDLE: None, RunProcessKey.STATUS: ClientStatus.STARTED, } diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index dea7bb3fc7..2b3052ee3d 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -26,7 +26,15 @@ from nvflare.apis.client import Client from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLContext -from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, 
FLMetaKey, SiteType, SystemVarName, WorkspaceConstants +from nvflare.apis.fl_constant import ( + ConfigVarName, + FLContextKey, + FLMetaKey, + JobConstants, + SiteType, + SystemVarName, + WorkspaceConstants, +) from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.apis.job_def import JobMetaKey from nvflare.apis.utils.decomposers import flare_decomposers @@ -403,7 +411,7 @@ def get_target_names(targets): return target_names -def get_return_code(process, job_id, workspace, logger): +def get_return_code(job_handle, job_id, workspace, logger): run_dir = os.path.join(workspace, job_id) rc_file = os.path.join(run_dir, FLMetaKey.PROCESS_RC_FILE) if os.path.exists(rc_file): @@ -414,11 +422,11 @@ def get_return_code(process, job_id, workspace, logger): except Exception: logger.warning( f"Could not get the return_code from {rc_file} of the job:{job_id}, " - f"Return the RC from the process:{process.pid}" + f"Return the RC from the job_handle:{job_handle}" ) - return_code = process.poll() + return_code = job_handle.poll() else: - return_code = process.poll() + return_code = job_handle.poll() return return_code @@ -434,6 +442,30 @@ def add_custom_dir_to_path(app_custom_folder, new_env): new_env[SystemVarName.PYTHONPATH] = app_custom_folder +def extract_participants(participants_list): + participants = [] + for item in participants_list: + if isinstance(item, str): + participants.append(item) + elif isinstance(item, dict): + sites = item.get(JobConstants.SITES) + participants.extend(sites) + else: + raise ValueError(f"Must be tye of str or dict, but got {type(item)}") + return participants + + +def extract_job_image(job_meta, site_name): + deploy_map = job_meta.get(JobMetaKey.DEPLOY_MAP, {}) + for _, participants in deploy_map.items(): + for item in participants: + if isinstance(item, dict): + sites = item.get(JobConstants.SITES) + if site_name in sites: + return item.get(JobConstants.JOB_IMAGE) + return None + + def _scope_prop_key(scope_name: str, key: 
str): return f"{scope_name}::{key}" diff --git a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py index 54c6fc49b7..8e0520053c 100644 --- a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py +++ b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py @@ -134,7 +134,7 @@ def persist_components(self, fl_ctx: FLContext, completed: bool): def restore_components(self, snapshot, fl_ctx: FLContext): pass - def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): + def start_client_job(self, job, client_sites, fl_ctx: FLContext): pass def check_client_resources( diff --git a/tests/unit_test/private/fed/utils/fed_utils_test.py b/tests/unit_test/private/fed/utils/fed_utils_test.py index 5de4b5e0ed..1a2f784eba 100644 --- a/tests/unit_test/private/fed/utils/fed_utils_test.py +++ b/tests/unit_test/private/fed/utils/fed_utils_test.py @@ -18,6 +18,7 @@ from nvflare.fuel.utils.fobs import Decomposer from nvflare.fuel.utils.fobs.datum import DatumManager from nvflare.fuel.utils.fobs.fobs import register_custom_folder +from nvflare.private.fed.utils.fed_utils import extract_job_image, extract_participants class ExampleTestClass: @@ -50,3 +51,40 @@ def test_custom_fobs_initialize(self): decomposer = ExampleTestClassDecomposer() decomposers = fobs.fobs._decomposers assert decomposer in list(decomposers.values()) + + def test_extract_participants(self): + participants = ["site-1", "site-2"] + results = extract_participants(participants) + expected = ["site-1", "site-2"] + assert results == expected + + participants = ["@ALL"] + results = extract_participants(participants) + expected = ["@ALL"] + assert results == expected + + def test_extract_participants_with_image(self): + participants = [ + "site-1", + "site-2", + {"sites": ["site-3", "site-4"], "image": "image1"}, + {"sites": ["site-5"], "image": "image2"}, + ] + results = extract_participants(participants) + expected 
= ["site-1", "site-2", "site-3", "site-4", "site-5"] + assert results == expected + + def test_extract_job_image(self): + job_meta = {"deploy_map": {"app": ["site-1", "site-2", {"sites": ["site-3", "site-4"], "image": "image1"}]}} + result = extract_job_image(job_meta, "site-3") + expected = "image1" + assert result == expected + + result = extract_job_image(job_meta, "site-1") + expected = None + assert result == expected + + job_meta = {"deploy_map": {"app": ["site-1", "site-2"]}} + result = extract_job_image(job_meta, "site-1") + expected = None + assert result == expected