From 64dd32591da10251958ae5f0b3a0fb7466c56f52 Mon Sep 17 00:00:00 2001 From: Javier Date: Sat, 13 Jul 2024 08:39:50 +0200 Subject: [PATCH] feat(framework) Introduce new `client_fn` signature passing the `Context` (#3779) Co-authored-by: Daniel J. Beutel --- src/py/flwr/client/app.py | 7 ++-- src/py/flwr/client/client_app.py | 34 +++++++++++++++---- .../client/message_handler/message_handler.py | 2 +- .../message_handler/message_handler_test.py | 6 ++-- src/py/flwr/client/typing.py | 4 +-- .../fleet/vce/backend/raybackend_test.py | 4 +-- .../server/superlink/fleet/vce/vce_api.py | 4 ++- .../ray_transport/ray_client_proxy.py | 17 ++++++---- .../ray_transport/ray_client_proxy_test.py | 32 +++++++++-------- 9 files changed, 66 insertions(+), 44 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ffcc95489d62..380185ed26e7 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -28,7 +28,7 @@ from flwr.client.client import Client from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.client.typing import ClientFnExt -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -138,7 +138,7 @@ class `flwr.client.Client` (default: None) Starting an SSL-enabled gRPC client using system certificates: - >>> def client_fn(node_id: int, partition_id: Optional[int]): + >>> def client_fn(context: Context): >>> return FlowerClient() >>> >>> start_client( @@ -253,8 +253,7 @@ class `flwr.client.Client` (default: None) if client_fn is None: # Wrap `Client` instance in `client_fn` def single_client_factory( - node_id: int, # pylint: disable=unused-argument - partition_id: Optional[int], # pylint: disable=unused-argument + context: Context, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy raise ValueError( diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 663d83a8b19e..9566302d0721 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -30,21 +30,41 @@ from .typing import ClientAppCallable +def _alert_erroneous_client_fn() -> None: + raise ValueError( + "A `ClientApp` cannot make use of a `client_fn` that does " + "not have a signature in the form: `def client_fn(context: " + "Context)`. You can import the `Context` like this: " + "`from flwr.common import Context`" + ) + + def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt: client_fn_args = inspect.signature(client_fn).parameters + first_arg = list(client_fn_args.keys())[0] + + if len(client_fn_args) != 1: + _alert_erroneous_client_fn() + + first_arg_type = client_fn_args[first_arg].annotation - if not all(key in client_fn_args for key in ["node_id", "partition_id"]): + if first_arg_type is str or first_arg == "cid": + # Warn previous signature for `client_fn` seems to be used warn_deprecated_feature( - "`client_fn` now expects a signature `def client_fn(node_id: int, " - "partition_id: Optional[int])`.\nYou provided `client_fn` with signature: " - f"{dict(client_fn_args.items())}" + "`client_fn` now expects a signature `def client_fn(context: Context)`." + "The provided `client_fn` has signature: " + f"{dict(client_fn_args.items())}. You can import the `Context` like this:" + " `from flwr.common import Context`" ) # Wrap depcreated client_fn inside a function with the expected signature def adaptor_fn( - node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument - ) -> Client: - return client_fn(str(partition_id)) # type: ignore + context: Context, + ) -> Client: # pylint: disable=unused-argument + # if patition-id is defined, pass it. Else pass node_id that should + # always be defined during Context init. + cid = context.node_config.get("partition-id", context.node_id) + return client_fn(str(cid)) # type: ignore return adaptor_fn diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e9a853a92101..1ab84eb01468 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype( client_fn: ClientFnExt, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most mod.""" - client = client_fn(message.metadata.dst_node_id, context.partition_id) + client = client_fn(context) # Check if NumPyClient is returend if isinstance(client, NumPyClient): diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 96de7ce0c2cb..557d61ffb32a 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -19,7 +19,7 @@ import unittest import uuid from copy import copy -from typing import List, Optional +from typing import List from flwr.client import Client from flwr.client.typing import ClientFnExt @@ -114,9 +114,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: def _get_client_fn(client: Client) -> ClientFnExt: - def client_fn( - node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument - ) -> Client: + def client_fn(contex: Context) -> Client: # pylint: disable=unused-argument return client return client_fn diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index bf66a9082c77..9faed4bc7283 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -15,7 +15,7 @@ """Custom types for Flower clients.""" -from typing import Callable, Optional +from typing import Callable from flwr.common import Context, Message @@ -23,7 +23,7 @@ # Compatibility ClientFn = Callable[[str], Client] -ClientFnExt = Callable[[int, Optional[int]], Client] +ClientFnExt = Callable[[Context], Client] ClientAppCallable = Callable[[Message, Context], Message] Mod = Callable[[Message, Context, ClientAppCallable], Message] diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index da4390194d05..3abdac7a232b 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -53,9 +53,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: return {"result": result} -def get_dummy_client( - node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument -) -> Client: +def get_dummy_client(context: Context) -> Client: # pylint: disable=unused-argument """Return a DummyClient converted to Client type.""" return DummyClient().to_client() diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 134fd34ed8f0..422148489b4a 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -285,7 +285,9 @@ def start_vce( node_states: Dict[int, NodeState] = {} for node_id, partition_id in nodes_mapping.items(): node_states[node_id] = NodeState( - node_id=node_id, node_config={}, partition_id=partition_id + node_id=node_id, + node_config={"partition-id": str(partition_id)}, + partition_id=None, ) # Load backend config diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index f2684016048e..b62e04aeed79 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -60,7 +60,9 @@ def _load_app() -> ClientApp: self.app_fn = _load_app self.actor_pool = actor_pool self.proxy_state = NodeState( - node_id=node_id, node_config={}, partition_id=self.partition_id + node_id=node_id, + node_config={"partition-id": str(partition_id)}, + partition_id=None, ) def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: @@ -70,18 +72,19 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: # Register state self.proxy_state.register_context(run_id=run_id) - # Retrieve state - state = self.proxy_state.retrieve_context(run_id=run_id) + # Retrieve context + context = self.proxy_state.retrieve_context(run_id=run_id) + partition_id_str = context.node_config["partition-id"] try: self.actor_pool.submit_client_job( - lambda a, a_fn, mssg, partition_id, state: a.run.remote( - a_fn, mssg, partition_id, state + lambda a, a_fn, mssg, partition_id, context: a.run.remote( + a_fn, mssg, partition_id, context ), - (self.app_fn, message, str(self.partition_id), state), + (self.app_fn, message, partition_id_str, context), ) out_mssg, updated_context = self.actor_pool.get_client_result( - str(self.partition_id), timeout + partition_id_str, timeout ) # Update state diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 8831e5f475ea..1d44ea1d8d2b 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -17,7 +17,7 @@ from math import pi from random import shuffle -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Tuple, Type import ray @@ -39,7 +39,10 @@ recordset_to_getpropertiesres, ) from flwr.common.recordset_compat_test import _get_valid_getpropertiesins -from flwr.simulation.app import _create_node_id_to_partition_mapping +from flwr.simulation.app import ( + NodeToPartitionMapping, + _create_node_id_to_partition_mapping, +) from flwr.simulation.ray_transport.ray_actor import ( ClientAppActor, VirtualClientEngineActor, @@ -65,16 +68,16 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: return {"result": result} -def get_dummy_client( - node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument -) -> Client: +def get_dummy_client(context: Context) -> Client: """Return a DummyClient converted to Client type.""" - return DummyClient(node_id).to_client() + return DummyClient(context.node_id).to_client() def prep( actor_type: Type[VirtualClientEngineActor] = ClientAppActor, -) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover +) -> Tuple[ + List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping +]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0} @@ -101,7 +104,7 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]: for node_id, partition_id in mapping.items() ] - return proxies, pool + return proxies, pool, mapping def test_cid_consistency_one_at_a_time() -> None: @@ -109,7 +112,7 @@ def test_cid_consistency_one_at_a_time() -> None: Submit one job and waits for completion. Then submits the next and so on """ - proxies, _ = prep() + proxies, _, _ = prep() getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) @@ -139,7 +142,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: All jobs are submitted at the same time. Then fetched one at a time. This also tests NodeState (at each Proxy) and RunState basic functionality. """ - proxies, _ = prep() + proxies, _, _ = prep() run_id = 0 getproperties_ins = _get_valid_getpropertiesins() @@ -186,9 +189,8 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: def test_cid_consistency_without_proxies() -> None: """Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy.""" - proxies, pool = prep() - num_clients = len(proxies) - node_ids = list(range(num_clients)) + _, pool, mapping = prep() + node_ids = list(mapping.keys()) getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) @@ -219,11 +221,11 @@ def _load_app() -> ClientApp: message, str(node_id), Context( - node_id=0, + node_id=node_id, node_config={}, state=RecordSet(), run_config={}, - partition_id=node_id, + partition_id=mapping[node_id], ), ), )