Skip to content

Commit

Permalink
feat(framework) Add exception handling to SuperNode for graceful exit…
Browse files Browse the repository at this point in the history
… when stopped (#4668)
  • Loading branch information
chongshenng authored Dec 16, 2024
1 parent fae9df9 commit 830b795
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 8 deletions.
12 changes: 11 additions & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
from flwr.common.typing import Fab, Run, UserConfig
from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
Expand Down Expand Up @@ -612,6 +612,16 @@ def _on_backoff(retry_state: RetryState) -> None:
send(reply_message)
log(INFO, "Sent reply")

except RunNotRunningException:
log(INFO, "")
log(
INFO,
"SuperNode aborted sending the reply message. "
"Run ID %s is not in `RUNNING` status.",
run_id,
)
log(INFO, "")

except StopIteration:
sleep_duration = 0
break
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/clientapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def run_clientapp( # pylint: disable=R0914

# Execute ClientApp
reply_message = client_app(message=message, context=context)

except Exception as ex: # pylint: disable=broad-exception-caught
# Don't update/change NodeState

Expand Down
14 changes: 10 additions & 4 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto
from flwr.common.typing import Fab, Run
from flwr.common.typing import Fab, Run, RunNotRunningException
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand Down Expand Up @@ -155,10 +155,16 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
ping_thread: Optional[threading.Thread] = None
ping_stop_event = threading.Event()

def _should_giveup_fn(e: Exception) -> bool:
    """Tell the retry invoker whether to stop retrying after `e`.

    Parameters
    ----------
    e : Exception
        The raised error. It must expose a `code()` method returning a
        `grpc.StatusCode` (presumably a `grpc.RpcError` -- confirm against
        the retry invoker's configuration).

    Returns
    -------
    bool
        `False` (keep retrying) only when the status code is `UNAVAILABLE`;
        `True` (give up) for every other status code.

    Raises
    ------
    RunNotRunningException
        If the status code is `PERMISSION_DENIED`.
    """
    status = e.code()  # type: ignore

    # PERMISSION_DENIED is escalated as a dedicated exception rather than
    # being reported as a plain "give up" decision.
    if status == grpc.StatusCode.PERMISSION_DENIED:
        raise RunNotRunningException

    # Retrying is restricted to UNAVAILABLE; anything else gives up.
    return status != grpc.StatusCode.UNAVAILABLE

# Restrict retries to cases where the status code is UNAVAILABLE
retry_invoker.should_giveup = (
lambda e: e.code() != grpc.StatusCode.UNAVAILABLE # type: ignore
)
# If the status code is PERMISSION_DENIED, additionally raise RunNotRunningException
retry_invoker.should_giveup = _should_giveup_fn

###########################################################################
# ping/create_node/delete_node/receive/send/get_run functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
return response


def get_run(
request: GetRunRequest, state: LinkState # pylint: disable=W0613
) -> GetRunResponse:
def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
"""Get run information."""
run = state.get_run(request.run_id)

Expand Down

0 comments on commit 830b795

Please sign in to comment.