Skip to content

Commit

Permalink
feat(framework) Prevent zombie subprocesses when the main process rec…
Browse files Browse the repository at this point in the history
…eives `SIGKILL` (#4826)
  • Loading branch information
panh99 authored Jan 17, 2025
1 parent 3db587c commit 471caa1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
16 changes: 14 additions & 2 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@


import multiprocessing
import os
import signal
import sys
import threading
import time
from contextlib import AbstractContextManager
from dataclasses import dataclass
Expand Down Expand Up @@ -553,7 +555,7 @@ def _on_backoff(retry_state: RetryState) -> None:

proc = mp_spawn_context.Process(
target=_run_flwr_clientapp,
args=(command,),
args=(command, os.getpid()),
daemon=True,
)
proc.start()
Expand Down Expand Up @@ -829,7 +831,17 @@ def signal_handler(sig, frame): # type: ignore
signal.signal(signal.SIGTERM, signal_handler)


def _run_flwr_clientapp(args: list[str]) -> None:
def _run_flwr_clientapp(args: list[str], main_pid: int) -> None:
# Monitor the main process in case of SIGKILL
def main_process_monitor() -> None:
while True:
time.sleep(1)
if os.getppid() != main_pid:
os.kill(os.getpid(), 9)

threading.Thread(target=main_process_monitor, daemon=True).start()

# Run the command
sys.argv = args
flwr_clientapp()

Expand Down
15 changes: 13 additions & 2 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import importlib.util
import multiprocessing
import multiprocessing.context
import os
import sys
import threading
from collections.abc import Sequence
Expand Down Expand Up @@ -447,7 +448,17 @@ def run_superlink() -> None:
sys.exit(1)


def _run_flwr_command(args: list[str]) -> None:
def _run_flwr_command(args: list[str], main_pid: int) -> None:
# Monitor the main process in case of SIGKILL
def main_process_monitor() -> None:
while True:
sleep(1)
if os.getppid() != main_pid:
os.kill(os.getpid(), 9)

threading.Thread(target=main_process_monitor, daemon=True).start()

# Run the command
sys.argv = args
if args[0] == "flwr-serverapp":
flwr_serverapp()
Expand Down Expand Up @@ -493,7 +504,7 @@ def _flwr_scheduler(
]

proc = mp_spawn_context.Process(
target=_run_flwr_command, args=(command,), daemon=True
target=_run_flwr_command, args=(command, os.getpid()), daemon=True
)
proc.start()

Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/server/serverapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
log_uploader = None
success = True
hash_run_id = None
run_status = None
while True:

try:
Expand Down

0 comments on commit 471caa1

Please sign in to comment.