From 471caa10a0d1b127c1deb2e08fe59df49e0e87dc Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 17 Jan 2025 10:47:11 +0000 Subject: [PATCH] feat(framework) Prevent zombie subprocesses when the main process receives `SIGKILL` (#4826) --- src/py/flwr/client/app.py | 16 ++++++++++++++-- src/py/flwr/server/app.py | 15 +++++++++++++-- src/py/flwr/server/serverapp/app.py | 1 + 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ff5a0cb4d617..6d83eebcd5f9 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -16,8 +16,10 @@ import multiprocessing +import os import signal import sys +import threading import time from contextlib import AbstractContextManager from dataclasses import dataclass @@ -553,7 +555,7 @@ def _on_backoff(retry_state: RetryState) -> None: proc = mp_spawn_context.Process( target=_run_flwr_clientapp, - args=(command,), + args=(command, os.getpid()), daemon=True, ) proc.start() @@ -829,7 +831,17 @@ def signal_handler(sig, frame): # type: ignore signal.signal(signal.SIGTERM, signal_handler) -def _run_flwr_clientapp(args: list[str]) -> None: +def _run_flwr_clientapp(args: list[str], main_pid: int) -> None: + # Monitor the main process in case of SIGKILL + def main_process_monitor() -> None: + while True: + time.sleep(1) + if os.getppid() != main_pid: + os.kill(os.getpid(), 9) + + threading.Thread(target=main_process_monitor, daemon=True).start() + + # Run the command sys.argv = args flwr_clientapp() diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index bce92669928d..e8c8b73f642d 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -20,6 +20,7 @@ import importlib.util import multiprocessing import multiprocessing.context +import os import sys import threading from collections.abc import Sequence @@ -447,7 +448,17 @@ def run_superlink() -> None: sys.exit(1) -def _run_flwr_command(args: list[str]) -> None: +def _run_flwr_command(args: list[str], main_pid: int) -> None: + # Monitor the main process in case of SIGKILL + def main_process_monitor() -> None: + while True: + sleep(1) + if os.getppid() != main_pid: + os.kill(os.getpid(), 9) + + threading.Thread(target=main_process_monitor, daemon=True).start() + + # Run the command sys.argv = args if args[0] == "flwr-serverapp": flwr_serverapp() @@ -493,7 +504,7 @@ def _flwr_scheduler( ] proc = mp_spawn_context.Process( - target=_run_flwr_command, args=(command,), daemon=True + target=_run_flwr_command, args=(command, os.getpid()), daemon=True ) proc.start() diff --git a/src/py/flwr/server/serverapp/app.py b/src/py/flwr/server/serverapp/app.py index 39ecc4899d8a..35e1df92ab67 100644 --- a/src/py/flwr/server/serverapp/app.py +++ b/src/py/flwr/server/serverapp/app.py @@ -117,6 +117,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915 log_uploader = None success = True hash_run_id = None + run_status = None while True: try: