Skip to content

Commit

Permalink
feat(framework) Add connection cleanup for SuperNode graceful exit (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jan 17, 2025
1 parent 471caa1 commit 3128ab4
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 41 deletions.
40 changes: 4 additions & 36 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

import multiprocessing
import os
import signal
import sys
import threading
import time
from contextlib import AbstractContextManager
from dataclasses import dataclass
from logging import ERROR, INFO, WARN
from os import urandom
from pathlib import Path
Expand Down Expand Up @@ -348,10 +346,7 @@ def _load_client_app(_1: str, _2: str, _3: str) -> ClientApp:
transport, server_address
)

app_state_tracker = _AppStateTracker()

def _on_sucess(retry_state: RetryState) -> None:
app_state_tracker.is_connected = True
if retry_state.tries > 1:
log(
INFO,
Expand All @@ -361,7 +356,6 @@ def _on_sucess(retry_state: RetryState) -> None:
)

def _on_backoff(retry_state: RetryState) -> None:
app_state_tracker.is_connected = False
if retry_state.tries == 1:
log(WARN, "Connection attempt failed, retrying...")
else:
Expand Down Expand Up @@ -398,7 +392,7 @@ def _on_backoff(retry_state: RetryState) -> None:

runs: dict[int, Run] = {}

while not app_state_tracker.interrupt:
while True:
sleep_duration: int = 0
with connection(
address,
Expand Down Expand Up @@ -437,9 +431,8 @@ def _on_backoff(retry_state: RetryState) -> None:
node_config=node_config,
)

app_state_tracker.register_signal_handler()
# pylint: disable=too-many-nested-blocks
while not app_state_tracker.interrupt:
while True:
try:
# Receive
message = receive()
Expand Down Expand Up @@ -597,10 +590,7 @@ def _on_backoff(retry_state: RetryState) -> None:
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
exc_entity = "SuperNode"

if not app_state_tracker.interrupt:
log(
ERROR, "%s raised an exception", exc_entity, exc_info=ex
)
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)

# Create error message
reply_message = message.create_error_reply(
Expand All @@ -626,19 +616,14 @@ def _on_backoff(retry_state: RetryState) -> None:
run_id,
)
log(INFO, "")

except StopIteration:
sleep_duration = 0
break
# pylint: enable=too-many-nested-blocks

# Unregister node
if delete_node is not None and app_state_tracker.is_connected:
if delete_node is not None:
delete_node() # pylint: disable=not-callable

if sleep_duration == 0:
log(INFO, "Disconnect and shut down")
del app_state_tracker
break

# Sleep and reconnect afterwards
Expand Down Expand Up @@ -814,23 +799,6 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
return connection, address, error_type


@dataclass
class _AppStateTracker:
interrupt: bool = False
is_connected: bool = False

def register_signal_handler(self) -> None:
"""Register handlers for exit signals."""

def signal_handler(sig, frame): # type: ignore
# pylint: disable=unused-argument
self.interrupt = True
raise StopIteration from None

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


def _run_flwr_clientapp(args: list[str], main_pid: int) -> None:
# Monitor the main process in case of SIGKILL
def main_process_monitor() -> None:
Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,13 @@ def get_fab(fab_hash: str, run_id: int) -> Fab:
yield (receive, send, create_node, delete_node, get_run, get_fab)
except Exception as exc: # pylint: disable=broad-except
log(ERROR, exc)
# Cleanup
finally:
try:
if node is not None:
# Disable retrying
retry_invoker.max_tries = 1
delete_node()
except grpc.RpcError:
pass
channel.close()
10 changes: 10 additions & 0 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from cryptography.hazmat.primitives.asymmetric import ec
from google.protobuf.message import Message as GrpcMessage
from requests.exceptions import ConnectionError as RequestsConnectionError

from flwr.client.heartbeat import start_ping_loop
from flwr.client.message_handler.message_handler import validate_out_message
Expand Down Expand Up @@ -379,3 +380,12 @@ def get_fab(fab_hash: str, run_id: int) -> Fab:
yield (receive, send, create_node, delete_node, get_run, get_fab)
except Exception as exc: # pylint: disable=broad-except
log(ERROR, exc)
# Cleanup
finally:
try:
if node is not None:
# Disable retrying
retry_invoker.max_tries = 1
delete_node()
except RequestsConnectionError:
pass
10 changes: 5 additions & 5 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def run_supernode() -> None:

log(DEBUG, "Isolation mode: %s", args.isolation)

# Register handlers for graceful shutdown
register_exit_handlers(
event_type=EventType.RUN_SUPERNODE_LEAVE,
)

start_client_internal(
server_address=args.superlink,
load_client_app_fn=load_fn,
Expand All @@ -103,11 +108,6 @@ def run_supernode() -> None:
clientappio_api_address=args.clientappio_api_address,
)

# Graceful shutdown
register_exit_handlers(
event_type=EventType.RUN_SUPERNODE_LEAVE,
)


def run_client_app() -> None:
"""Run Flower client app."""
Expand Down

0 comments on commit 3128ab4

Please sign in to comment.