From 4b5701707fc6a5bc7ab82bb93a8e2e657c159918 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 2 Apr 2024 11:10:15 +0100 Subject: [PATCH 1/3] Set `ping_interval` in `State.create_node` (#3190) --- src/py/flwr/common/constant.py | 1 + .../superlink/fleet/message_handler/message_handler.py | 2 +- src/py/flwr/server/superlink/fleet/vce/vce_api.py | 3 ++- src/py/flwr/server/superlink/state/in_memory_state.py | 6 ++---- src/py/flwr/server/superlink/state/sqlite_state.py | 6 ++---- src/py/flwr/server/superlink/state/state.py | 2 +- src/py/flwr/server/superlink/state/state_test.py | 8 ++++---- 7 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 3ee60f6222f9..dd100ba25d25 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -41,6 +41,7 @@ PING_CALL_TIMEOUT = 5 PING_BASE_MULTIPLIER = 0.8 PING_RANDOM_RANGE = (-0.1, 0.1) +PING_MAX_INTERVAL = 1e300 class MessageType: diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 9fa7656198e5..39edd606b464 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -43,7 +43,7 @@ def create_node( ) -> CreateNodeResponse: """.""" # Create node - node_id = state.create_node() + node_id = state.create_node(ping_interval=request.ping_interval) return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 5fec10940343..ea74bf492ab9 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -24,6 +24,7 @@ from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.client.node_state import NodeState +from 
flwr.common.constant import PING_MAX_INTERVAL from flwr.common.logger import log from flwr.common.message import Error from flwr.common.object_ref import load_app @@ -43,7 +44,7 @@ def _register_nodes( nodes_mapping: NodeToPartitionMapping = {} state = state_factory.state() for i in range(num_nodes): - node_id = state.create_node() + node_id = state.create_node(ping_interval=PING_MAX_INTERVAL) nodes_mapping[node_id] = i log(INFO, "Registered %i nodes", len(nodes_mapping)) return nodes_mapping diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 6fc57707ac36..2ce6dcd4599a 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -182,16 +182,14 @@ def num_task_res(self) -> int: """ return len(self.task_res_store) - def create_node(self) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in state, and return `node_id`.""" # Sample a random int64 as node_id node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) with self.lock: if node_id not in self.node_ids: - # Default ping interval is 30s - # TODO: change 1e9 to 30s # pylint: disable=W0511 - self.node_ids[node_id] = (time.time() + 1e9, 1e9) + self.node_ids[node_id] = (time.time() + ping_interval, ping_interval) return node_id log(ERROR, "Unexpected node registration failure.") return 0 diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 6996d51d2a9b..b68d19bd96d9 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -468,7 +468,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: return None - def create_node(self) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in state, and return `node_id`.""" # Sample a random int64 as node_id node_id: int = 
int.from_bytes(os.urandom(8), "little", signed=True) @@ -478,9 +478,7 @@ def create_node(self) -> int: ) try: - # Default ping interval is 30s - # TODO: change 1e9 to 30s # pylint: disable=W0511 - self.query(query, (node_id, time.time() + 1e9, 1e9)) + self.query(query, (node_id, time.time() + ping_interval, ping_interval)) except sqlite3.IntegrityError: log(ERROR, "Unexpected node registration failure.") return 0 diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 313290eb1022..b356cd47befa 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -132,7 +132,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" @abc.abstractmethod - def create_node(self) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in state, and return `node_id`.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 1757cfac4255..8e49a380bb16 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -319,7 +319,7 @@ def test_create_node_and_get_nodes(self) -> None: # Execute for _ in range(10): - node_ids.append(state.create_node()) + node_ids.append(state.create_node(ping_interval=10)) retrieved_node_ids = state.get_nodes(run_id) # Assert @@ -331,7 +331,7 @@ def test_delete_node(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run() - node_id = state.create_node() + node_id = state.create_node(ping_interval=10) # Execute state.delete_node(node_id) @@ -346,7 +346,7 @@ def test_get_nodes_invalid_run_id(self) -> None: state: State = self.state_factory() state.create_run() invalid_run_id = 61016 - state.create_node() + state.create_node(ping_interval=10) # Execute retrieved_node_ids = state.get_nodes(invalid_run_id) @@ -399,7 
+399,7 @@ def test_acknowledge_ping(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run() - node_ids = [state.create_node() for _ in range(100)] + node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) for node_id in node_ids[70:]: From 987b1985ecd4e2a7fefbea501d9879fc47a964cb Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 2 Apr 2024 13:56:35 +0200 Subject: [PATCH 2/3] Bump up datasets version (#3193) --- datasets/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml index 5800faf3f272..7dfa60138582 100644 --- a/datasets/pyproject.toml +++ b/datasets/pyproject.toml @@ -54,7 +54,7 @@ exclude = [ [tool.poetry.dependencies] python = "^3.8" numpy = "^1.21.0" -datasets = "^2.14.3" +datasets = "^2.14.6" pillow = { version = ">=6.2.1", optional = true } soundfile = { version = ">=0.12.1", optional = true } librosa = { version = ">=0.10.0.post2", optional = true } From f95d641cefabb326b389250e6a431ff92c4ccd60 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 2 Apr 2024 13:39:29 +0100 Subject: [PATCH 3/3] Introduce `ClientAppException` (#3191) --- src/py/flwr/client/app.py | 33 ++++++++++++++----- src/py/flwr/client/client_app.py | 9 +++++ src/py/flwr/common/constant.py | 3 +- .../server/superlink/fleet/vce/vce_api.py | 17 +++++++--- .../simulation/ray_transport/ray_actor.py | 24 ++------------ 5 files changed, 50 insertions(+), 36 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 7104ba267f57..1720405ab867 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -34,6 +34,7 @@ TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, TRANSPORT_TYPES, + ErrorCode, ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, 
warn_deprecated_feature @@ -483,7 +484,7 @@ def _load_client_app() -> ClientApp: # Create an error reply message that will never be used to prevent # the used-before-assignment linting error reply_message = message.create_error_reply( - error=Error(code=0, reason="Unknown") + error=Error(code=ErrorCode.UNKNOWN, reason="Unknown") ) # Handle app loading and task message @@ -491,27 +492,41 @@ def _load_client_app() -> ClientApp: # Load ClientApp instance client_app: ClientApp = load_client_app_fn() + # Execute ClientApp reply_message = client_app(message=message, context=context) - # Update node state - node_state.update_context( - run_id=message.metadata.run_id, - context=context, - ) except Exception as ex: # pylint: disable=broad-exception-caught - log(ERROR, "ClientApp raised an exception", exc_info=ex) # Legacy grpc-bidi if transport in ["grpc-bidi", None]: + log(ERROR, "Client raised an exception.", exc_info=ex) # Raise exception, crash process raise ex # Don't update/change NodeState - # Create error message + e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION # Reason example: ":<'division by zero'>" reason = str(type(ex)) + ":<'" + str(ex) + "'>" + exc_entity = "ClientApp" + if isinstance(ex, LoadClientAppError): + reason = ( + "An exception was raised when attempting to load " + "`ClientApp`" + ) + e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION + exc_entity = "SuperNode" + + log(ERROR, "%s raised an exception", exc_entity, exc_info=ex) + + # Create error message reply_message = message.create_error_reply( - error=Error(code=0, reason=reason) + error=Error(code=e_code, reason=reason) + ) + else: + # No exception, update node state + node_state.update_context( + run_id=message.metadata.run_id, + context=context, ) # Send diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 79e7720cbb8e..c9d337700147 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -28,6 +28,15 @@ from .typing import 
ClientAppCallable +class ClientAppException(Exception): + """Exception raised when an exception is raised while executing a ClientApp.""" + + def __init__(self, message: str): + ex_name = self.__class__.__name__ + self.message = f"\nException {ex_name} occurred. Message: " + message + super().__init__(self.message) + + class ClientApp: """Flower ClientApp. diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index dd100ba25d25..6a4061a72505 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -81,7 +81,8 @@ class ErrorCode: """Error codes for Message's Error.""" UNKNOWN = 0 - CLIENT_APP_RAISED_EXCEPTION = 1 + LOAD_CLIENT_APP_EXCEPTION = 1 + CLIENT_APP_RAISED_EXCEPTION = 2 def __new__(cls) -> ErrorCode: """Prevent instantiation.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index ea74bf492ab9..9c27fca79c12 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -22,9 +22,9 @@ from logging import DEBUG, ERROR, INFO, WARN from typing import Callable, Dict, List, Optional -from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState -from flwr.common.constant import PING_MAX_INTERVAL +from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode from flwr.common.logger import log from flwr.common.message import Error from flwr.common.object_ref import load_app @@ -94,9 +94,18 @@ async def worker( except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, ex) log(ERROR, traceback.format_exc()) + + if isinstance(ex, ClientAppException): + e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION + elif isinstance(ex, LoadClientAppError): + e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION + else: + e_code = ErrorCode.UNKNOWN + reason = 
str(type(ex)) + ":<'" + str(ex) + "'>" - error = Error(code=0, reason=reason) - out_mssg = message.create_error_reply(error=error) + out_mssg = message.create_error_reply( + error=Error(code=e_code, reason=reason) + ) finally: if out_mssg: diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 9773203628ab..9caf0fc3e6c0 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -16,7 +16,6 @@ import asyncio import threading -import traceback from abc import ABC from logging import DEBUG, ERROR, WARNING from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union @@ -25,22 +24,13 @@ from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.common import Context, Message from flwr.common.logger import log ClientAppFn = Callable[[], ClientApp] -class ClientException(Exception): - """Raised when client side logic crashes with an exception.""" - - def __init__(self, message: str): - div = ">" * 7 - self.message = "\n" + div + "A ClientException occurred." + message - super().__init__(self.message) - - class VirtualClientEngineActor(ABC): """Abstract base class for VirtualClientEngine Actors.""" @@ -71,17 +61,7 @@ def run( raise load_ex except Exception as ex: - client_trace = traceback.format_exc() - mssg = ( - "\n\tSomething went wrong when running your client run." - "\n\tClient " - + cid - + " crashed when the " - + self.__class__.__name__ - + " was running its run." - "\n\tException triggered on the client side: " + client_trace, - ) - raise ClientException(str(mssg)) from ex + raise ClientAppException(str(ex)) from ex return cid, out_message, context