Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shutdown method to Flare Client API #3152

Merged
merged 10 commits into from
Feb 7, 2025
9 changes: 5 additions & 4 deletions nvflare/app_common/executors/launcher_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def initialize(self, fl_ctx: FLContext) -> None:

def finalize(self, fl_ctx: FLContext) -> None:
self._execute_launcher_method_in_thread_executor(method_name="finalize", fl_ctx=fl_ctx)
self._thread_pool_executor.shutdown()

def handle_event(self, event_type: str, fl_ctx: FLContext) -> None:
if event_type == EventType.START_RUN:
Expand Down Expand Up @@ -295,7 +296,7 @@ def _wait_external_setup(self, task_name: str, fl_ctx: FLContext, abort_signal:

run_status = self.launcher.check_run_status(task_name, fl_ctx)
if run_status != LauncherRunStatus.RUNNING:
self.log_info(
self.log_error(
fl_ctx, f"External process has not called flare.init and run status becomes {run_status}."
)
return False
Expand All @@ -316,7 +317,7 @@ def _finalize_external_execution(
fl_ctx=fl_ctx,
)
if not self._received_result.is_set() and check_run_status != LauncherRunStatus.COMPLETE_SUCCESS:
self.log_warning(fl_ctx, f"Try to stop task ({task_name}) when launcher run status is {check_run_status}")
self.log_debug(fl_ctx, f"Try to stop task ({task_name}) when launcher run status is {check_run_status}")

self.log_info(fl_ctx, f"Calling stop task ({task_name}).")
stop_task_success = self._execute_launcher_method_in_thread_executor(
Expand Down Expand Up @@ -407,11 +408,11 @@ def _monitor_launcher(self, fl_ctx: FLContext):
self._launcher_finish = True
self.log_info(
fl_ctx,
f"launcher completed {task_name} with status {run_status} at time {self._launcher_finish_time}",
f"launcher completed with status {run_status} at time {self._launcher_finish_time}",
)

if run_status == LauncherRunStatus.COMPLETE_FAILED:
msg = f"Launcher failed with at time {self._launcher_finish_time} "
msg = f"Launcher failed at time {self._launcher_finish_time} "
self._abort_signal.trigger(msg)
break

Expand Down
87 changes: 52 additions & 35 deletions nvflare/app_common/launchers/subprocess_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import shlex
import subprocess
from threading import Thread
from threading import Lock, Thread
from typing import Optional

from nvflare.apis.fl_constant import FLContextKey
Expand Down Expand Up @@ -84,12 +84,21 @@ def log_subprocess_output(process, logger):


class SubprocessLauncher(Launcher):
def __init__(self, script: str, launch_once: bool = True, clean_up_script: Optional[str] = None):
def __init__(
self,
script: str,
launch_once: Optional[bool] = True,
clean_up_script: Optional[str] = None,
shutdown_timeout: Optional[float] = None,
):
"""Initializes the SubprocessLauncher.

Args:
script (str): Script to be launched using subprocess.
launch_once (bool): Whether the external process will be launched only once at the beginning or on each task.
clean_up_script (Optional[str]): Optional clean up script to be run after the main script execution.
shutdown_timeout (Optional[float]): Maximum number of seconds to wait for the external process to exit on its own before it is terminated.
None means wait indefinitely (never times out).
"""
super().__init__()

Expand All @@ -98,6 +107,8 @@ def __init__(self, script: str, launch_once: bool = True, clean_up_script: Optio
self._script = script
self._launch_once = launch_once
self._clean_up_script = clean_up_script
self._shutdown_timeout = shutdown_timeout
self._lock = Lock()
self.logger = get_obj_logger(self)

def initialize(self, fl_ctx: FLContext):
Expand All @@ -119,40 +130,46 @@ def stop_task(self, task_name: str, fl_ctx: FLContext, abort_signal: Signal) ->
self._stop_external_process()

def _start_external_process(self, fl_ctx: FLContext):
if self._process is None:
command = self._script
env = os.environ.copy()
env["CLIENT_API_TYPE"] = "EX_PROCESS_API"

workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
app_custom_folder = workspace.get_app_custom_dir(job_id)
add_custom_dir_to_path(app_custom_folder, env)

command_seq = shlex.split(command)
self._process = subprocess.Popen(
command_seq, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=self._app_dir, env=env
)
self._log_thread = Thread(target=log_subprocess_output, args=(self._process, self.logger))
self._log_thread.start()
with self._lock:
if self._process is None:
command = self._script
env = os.environ.copy()
env["CLIENT_API_TYPE"] = "EX_PROCESS_API"

workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
app_custom_folder = workspace.get_app_custom_dir(job_id)
add_custom_dir_to_path(app_custom_folder, env)

command_seq = shlex.split(command)
self._process = subprocess.Popen(
command_seq, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=self._app_dir, env=env
)
self._log_thread = Thread(target=log_subprocess_output, args=(self._process, self.logger))
self._log_thread.start()

def _stop_external_process(self):
if self._process:
self._process.terminate()
self._process.wait()
self._log_thread.join()
if self._clean_up_script:
command_seq = shlex.split(self._clean_up_script)
process = subprocess.Popen(command_seq, cwd=self._app_dir)
process.wait()
self._process = None
with self._lock:
if self._process:
try:
self._process.wait(self._shutdown_timeout)
except subprocess.TimeoutExpired:
pass
self._process.terminate()
self._log_thread.join()
if self._clean_up_script:
command_seq = shlex.split(self._clean_up_script)
process = subprocess.Popen(command_seq, cwd=self._app_dir)
process.wait()
self._process = None

def check_run_status(self, task_name: str, fl_ctx: FLContext) -> str:
if self._process is None:
return LauncherRunStatus.NOT_RUNNING
return_code = self._process.poll()
if return_code is None:
return LauncherRunStatus.RUNNING
if return_code == 0:
return LauncherRunStatus.COMPLETE_SUCCESS
return LauncherRunStatus.COMPLETE_FAILED
with self._lock:
if self._process is None:
return LauncherRunStatus.NOT_RUNNING
return_code = self._process.poll()
if return_code is None:
return LauncherRunStatus.RUNNING
if return_code == 0:
return LauncherRunStatus.COMPLETE_SUCCESS
return LauncherRunStatus.COMPLETE_FAILED
3 changes: 3 additions & 0 deletions nvflare/app_opt/lightning/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from torch import Tensor

from nvflare.app_common.abstract.fl_model import FLModel, MetaKey
from nvflare.app_opt.pt.decomposers import TensorDecomposer
from nvflare.client.api import clear, get_config, init, is_evaluate, is_submit_model, is_train, receive, send
from nvflare.client.config import ConfigKey
from nvflare.fuel.utils import fobs

from .callbacks import RestoreState

Expand Down Expand Up @@ -65,6 +67,7 @@ def __init__(self):
self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}

"""
fobs.register(TensorDecomposer)
callbacks = trainer.callbacks
if isinstance(callbacks, Callback):
callbacks = [callbacks]
Expand Down
1 change: 1 addition & 0 deletions nvflare/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .api import log as log
from .api import receive as receive
from .api import send as send
from .api import shutdown as shutdown
from .api import system_info as system_info
from .decorator import evaluate as evaluate
from .decorator import train as train
Expand Down
Loading
Loading