From 830b795d3f1bbd8067c552db01b1c3d98b3b2b4b Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 16 Dec 2024 11:49:36 +0000 Subject: [PATCH] feat(framework) Add exception handling to SuperNode for graceful exit when stopped (#4668) --- src/py/flwr/client/app.py | 12 +++++++++++- src/py/flwr/client/clientapp/app.py | 1 + src/py/flwr/client/grpc_rere_client/connection.py | 14 ++++++++++---- .../fleet/message_handler/message_handler.py | 4 +--- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index e245e5d96576..1f04efa99eeb 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -56,7 +56,7 @@ from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential -from flwr.common.typing import Fab, Run, UserConfig +from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes @@ -612,6 +612,16 @@ def _on_backoff(retry_state: RetryState) -> None: send(reply_message) log(INFO, "Sent reply") + except RunNotRunningException: + log(INFO, "") + log( + INFO, + "SuperNode aborted sending the reply message. " + "Run ID %s is not in `RUNNING` status.", + run_id, + ) + log(INFO, "") + except StopIteration: sleep_duration = 0 break diff --git a/src/py/flwr/client/clientapp/app.py b/src/py/flwr/client/clientapp/app.py index 96ff7e4fd2fc..32813205478a 100644 --- a/src/py/flwr/client/clientapp/app.py +++ b/src/py/flwr/client/clientapp/app.py @@ -140,6 +140,7 @@ def run_clientapp( # pylint: disable=R0914 # Execute ClientApp reply_message = client_app(message=message, context=context) + except Exception as ex: # pylint: disable=broad-exception-caught # Don't update/change NodeState diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index ed5745512186..8f33f0b3f07a 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -42,7 +42,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto -from flwr.common.typing import Fab, Run +from flwr.common.typing import Fab, Run, RunNotRunningException from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -155,10 +155,16 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917 ping_thread: Optional[threading.Thread] = None ping_stop_event = threading.Event() + def _should_giveup_fn(e: Exception) -> bool: + if e.code() == grpc.StatusCode.PERMISSION_DENIED: # type: ignore + raise RunNotRunningException + if e.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore + return False + return True + # Restrict retries to cases where the status code is UNAVAILABLE - retry_invoker.should_giveup = ( - lambda e: e.code() != grpc.StatusCode.UNAVAILABLE # type: ignore - ) + # If the status code is PERMISSION_DENIED, additionally raise RunNotRunningException + retry_invoker.should_giveup = _should_giveup_fn ########################################################################### # ping/create_node/delete_node/receive/send/get_run functions diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 34917b8168eb..c70c09cd7ac9 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -123,9 +123,7 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR return response -def get_run( - request: GetRunRequest, state: LinkState # pylint: disable=W0613 -) -> GetRunResponse: +def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse: """Get run information.""" run = state.get_run(request.run_id)