From 018dd45142a2d2375e1bfe12489ab1208bfafebc Mon Sep 17 00:00:00 2001 From: Javier Date: Sat, 8 Jun 2024 16:23:21 +0200 Subject: [PATCH] feat(framework) Introduce `RunTracker` (#3561) Co-authored-by: Charles Beauville --- src/py/flwr/client/app.py | 65 +++++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 20 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index d7c05d8afbb2..4e09c53c2b00 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -14,8 +14,10 @@ # ============================================================================== """Flower client app.""" +import signal import sys import time +from dataclasses import dataclass from logging import DEBUG, ERROR, INFO, WARN from typing import Callable, ContextManager, Optional, Tuple, Type, Union @@ -37,7 +39,7 @@ ) from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error -from flwr.common.retry_invoker import RetryInvoker, exponential +from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response @@ -263,6 +265,29 @@ def _load_client_app() -> ClientApp: transport, server_address ) + run_tracker = _RunTracker() + + def _on_sucess(retry_state: RetryState) -> None: + if retry_state.tries > 1: + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if run_tracker.create_node: + run_tracker.create_node() + + def _on_backoff(retry_state: RetryState) -> None: + if retry_state.tries == 1: + log(WARN, "Connection attempt failed, retrying...") + else: + log( + DEBUG, + "Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + retry_invoker = RetryInvoker( wait_gen_factory=exponential, recoverable_exceptions=connection_error_type, @@ -278,25 +303,8 @@ def _load_client_app() -> ClientApp: if retry_state.tries > 1 else None ), - on_success=lambda retry_state: ( - log( - INFO, - "Connection successful after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - if retry_state.tries > 1 - else None - ), - on_backoff=lambda retry_state: ( - log(WARN, "Connection attempt failed, retrying...") - if retry_state.tries == 1 - else log( - DEBUG, - "Connection attempt failed, retrying in %.2f seconds", - retry_state.actual_wait, - ) - ), + on_success=_on_sucess, + on_backoff=_on_backoff, ) node_state = NodeState() @@ -579,3 +587,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ ) return connection, address, error_type + + +@dataclass +class _RunTracker: + create_node: Optional[Callable[[], None]] = None + interrupt: bool = False + + def register_signal_handler(self) -> None: + """Register handlers for exit signals.""" + + def signal_handler(sig, frame): # type: ignore + # pylint: disable=unused-argument + self.interrupt = True + raise StopIteration from None + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler)