Skip to content

Commit

Permalink
feat(framework) Make flwr-serverapp / flwr-clientapp / `flwr-simu…
Browse files Browse the repository at this point in the history
…lation` subprocesses daemonic (#4798)
  • Loading branch information
panh99 authored Jan 10, 2025
1 parent 6f52a3c commit 9e97a50
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
20 changes: 14 additions & 6 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
"""Flower client app."""


import multiprocessing
import signal
import subprocess
import sys
import time
from contextlib import AbstractContextManager
Expand All @@ -34,6 +34,7 @@
from flwr.cli.install import install_from_fab
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.clientapp.app import flwr_clientapp
from flwr.client.nodestate.nodestate_factory import NodeStateFactory
from flwr.client.typing import ClientFnExt
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
Expand Down Expand Up @@ -391,6 +392,7 @@ def _on_backoff(retry_state: RetryState) -> None:
run_info_store: Optional[DeprecatedRunInfoStore] = None
state_factory = NodeStateFactory()
state = state_factory.state()
mp_spawn_context = multiprocessing.get_context("spawn")

runs: dict[int, Run] = {}

Expand Down Expand Up @@ -549,12 +551,13 @@ def _on_backoff(retry_state: RetryState) -> None:
]
command.append("--insecure")

subprocess.run(
command,
stdout=None,
stderr=None,
check=True,
proc = mp_spawn_context.Process(
target=_run_flwr_clientapp,
args=(command,),
daemon=True,
)
proc.start()
proc.join()
else:
# Wait for output to become available
while not clientappio_servicer.has_outputs():
Expand Down Expand Up @@ -826,6 +829,11 @@ def signal_handler(sig, frame): # type: ignore
signal.signal(signal.SIGTERM, signal_handler)


def _run_flwr_clientapp(args: list[str]) -> None:
sys.argv = args
flwr_clientapp()


def run_clientappio_api_grpc(
address: str,
certificates: Optional[tuple[bytes, bytes, bytes]],
Expand Down
38 changes: 31 additions & 7 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import argparse
import csv
import importlib.util
import subprocess
import multiprocessing
import multiprocessing.context
import sys
import threading
from collections.abc import Sequence
Expand Down Expand Up @@ -69,6 +70,8 @@
add_FleetServicer_to_server,
)
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
from flwr.server.serverapp.app import flwr_serverapp
from flwr.simulation.app import flwr_simulation
from flwr.superexec.app import load_executor
from flwr.superexec.exec_grpc import run_exec_api_grpc

Expand Down Expand Up @@ -444,22 +447,35 @@ def run_superlink() -> None:
sys.exit(1)


def _run_flwr_command(args: list[str]) -> None:
sys.argv = args
if args[0] == "flwr-serverapp":
flwr_serverapp()
elif args[0] == "flwr-simulation":
flwr_simulation()
else:
raise ValueError(f"Unknown command: {args[0]}")


def _flwr_scheduler(
state_factory: LinkStateFactory,
io_api_arg: str,
io_api_address: str,
cmd: str,
) -> None:
log(DEBUG, "Started %s scheduler thread.", cmd)

state = state_factory.state()
run_id_to_proc: dict[int, multiprocessing.context.SpawnProcess] = {}

# Use the "spawn" start method for multiprocessing.
mp_spawn_context = multiprocessing.get_context("spawn")

# Periodically check for a pending run in the LinkState
while True:
sleep(3)
sleep(0.1)
pending_run_id = state.get_pending_run_id()

if pending_run_id:
if pending_run_id and pending_run_id not in run_id_to_proc:

log(
INFO,
Expand All @@ -476,10 +492,18 @@ def _flwr_scheduler(
"--insecure",
]

subprocess.Popen( # pylint: disable=consider-using-with
command,
text=True,
proc = mp_spawn_context.Process(
target=_run_flwr_command, args=(command,), daemon=True
)
proc.start()

# Store the process
run_id_to_proc[pending_run_id] = proc

# Clean up finished processes
for run_id, proc in list(run_id_to_proc.items()):
if not proc.is_alive():
del run_id_to_proc[run_id]


def _format_address(address: str) -> tuple[str, str, int]:
Expand Down

0 comments on commit 9e97a50

Please sign in to comment.