From 9fc02fb364abe5274e0e02bc413f423d3b07f11f Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 2 Apr 2024 11:41:31 +0200 Subject: [PATCH] defined `ClientAppException` --- src/py/flwr/client/app.py | 9 ++++-- src/py/flwr/client/client_app.py | 52 ++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 7104ba267f57..9b6bd5b86568 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -24,7 +24,7 @@ from grpc import RpcError from flwr.client.client import Client -from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.typing import ClientFn from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address @@ -34,6 +34,7 @@ TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, TRANSPORT_TYPES, + ErrorCode, ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature @@ -507,11 +508,15 @@ def _load_client_app() -> ClientApp: # Don't update/change NodeState + e_code = ErrorCode.UNKNOWN + if isinstance(ex, ClientAppException): + e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION + # Create error message # Reason example: ":<'division by zero'>" reason = str(type(ex)) + ":<'" + str(ex) + "'>" reply_message = message.create_error_reply( - error=Error(code=0, reason=reason) + error=Error(code=e_code, reason=reason) ) # Send diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 79e7720cbb8e..89033dca496e 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -28,6 +28,15 @@ from .typing import ClientAppCallable +class ClientAppException(Exception): + """Exepction raised when an exception is raised while executing a ClientApp.""" + + def __init__(self, message: str): + ex_name = self.__class__.__name__ + self.message = f"\nA {ex_name} occurred." + message + super().__init__(self.message) + + class ClientApp: """Flower ClientApp. @@ -84,26 +93,29 @@ def ffn( def __call__(self, message: Message, context: Context) -> Message: """Execute `ClientApp`.""" - # Execute message using `client_fn` - if self._call: - return self._call(message, context) - - # Execute message using a new - if message.metadata.message_type == MessageType.TRAIN: - if self._train: - return self._train(message, context) - raise ValueError("No `train` function registered") - if message.metadata.message_type == MessageType.EVALUATE: - if self._evaluate: - return self._evaluate(message, context) - raise ValueError("No `evaluate` function registered") - if message.metadata.message_type == MessageType.QUERY: - if self._query: - return self._query(message, context) - raise ValueError("No `query` function registered") - - # Message type did not match one of the known message types abvoe - raise ValueError(f"Unknown message_type: {message.metadata.message_type}") + try: + # Execute message using `client_fn` + if self._call: + return self._call(message, context) + + # Execute message using a new + if message.metadata.message_type == MessageType.TRAIN: + if self._train: + return self._train(message, context) + raise ValueError("No `train` function registered") + if message.metadata.message_type == MessageType.EVALUATE: + if self._evaluate: + return self._evaluate(message, context) + raise ValueError("No `evaluate` function registered") + if message.metadata.message_type == MessageType.QUERY: + if self._query: + return self._query(message, context) + raise ValueError("No `query` function registered") + + # Message type did not match one of the known message types abvoe + raise ValueError(f"Unknown message_type: {message.metadata.message_type}") + except Exception as ex: + raise ClientAppException(str(ex)) from ex def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: """Return a decorator that registers the train fn with the client app.