From 850ccce687878f94b8a01d9ac65a7b402bcf1e68 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 8 Oct 2024 14:49:35 +0100 Subject: [PATCH] refactor(framework) Make `DriverClientProxy` use driver's `send_and_receive` (#4289) --- .../flwr/server/compat/driver_client_proxy.py | 44 ++++++--------- .../server/compat/driver_client_proxy_test.py | 53 +++++++------------ 2 files changed, 35 insertions(+), 62 deletions(-) diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index 7190786784ec..c5a3f561d474 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -15,7 +15,6 @@ """Flower ClientProxy implementation for Driver API.""" -import time from typing import Optional from flwr import common @@ -25,8 +24,6 @@ from ..driver.driver import Driver -SLEEP_TIME = 1 - class DriverClientProxy(ClientProxy): """Flower client proxy which delegates work using the Driver API.""" @@ -122,29 +119,18 @@ def _send_receive_recordset( ttl=timeout, ) - # Push message - message_ids = list(self.driver.push_messages(messages=[message])) - if len(message_ids) != 1: - raise ValueError("Unexpected number of message_ids") - - message_id = message_ids[0] - if message_id == "": - raise ValueError(f"Failed to send message to node {self.node_id}") - - if timeout: - start_time = time.time() - - while True: - messages = list(self.driver.pull_messages(message_ids)) - if len(messages) == 1: - msg: Message = messages[0] - if msg.has_error(): - raise ValueError( - f"Message contains an Error (reason: {msg.error.reason}). " - "It originated during client-side execution of a message." - ) - return msg.content - - if timeout is not None and time.time() > start_time + timeout: - raise RuntimeError("Timeout reached") - time.sleep(SLEEP_TIME) + # Send message and wait for reply + messages = list(self.driver.send_and_receive(messages=[message])) + + # A single reply is expected + if len(messages) != 1: + raise ValueError(f"Expected one Message but got: {len(messages)}") + + # Only messages without errors can be handled beyond these point + msg: Message = messages[0] + if msg.has_error(): + raise ValueError( + f"Message contains an Error (reason: {msg.error.reason}). " + "It originated during client-side execution of a message." + ) + return msg.content diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index a5b454c79f90..335f47cc7732 100644 --- a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -52,8 +52,6 @@ RUN_ID = 61016 NODE_ID = 1 -INSTRUCTION_MESSAGE_ID = "mock instruction message id" -REPLY_MESSAGE_ID = "mock reply message id" class DriverClientProxyTestCase(unittest.TestCase): @@ -77,7 +75,7 @@ def test_get_properties(self) -> None: """Test positive case.""" # Prepare res = GetPropertiesRes(status=CLIENT_STATUS, properties=CLIENT_PROPERTIES) - self.driver.push_messages.side_effect = self._get_push_messages(res) + self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res) request_properties: Config = {"tensor_type": "str"} ins = GetPropertiesIns(config=request_properties) @@ -95,7 +93,7 @@ def test_get_parameters(self) -> None: status=CLIENT_STATUS, parameters=MESSAGE_PARAMETERS, ) - self.driver.push_messages.side_effect = self._get_push_messages(res) + self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res) ins = GetParametersIns(config={}) # Execute @@ -114,7 +112,7 @@ def test_fit(self) -> None: num_examples=10, metrics={}, ) - self.driver.push_messages.side_effect = self._get_push_messages(res) + self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res) parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) ins = FitIns(parameters, {}) @@ -134,7 +132,7 @@ def test_evaluate(self) -> None: num_examples=0, metrics={}, ) - self.driver.push_messages.side_effect = self._get_push_messages(res) + self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res) parameters = Parameters(tensors=[b"random params%^&*F"], tensor_type="np") ins = EvaluateIns(parameters, {}) @@ -148,7 +146,7 @@ def test_evaluate(self) -> None: def test_get_properties_and_fail(self) -> None: """Test negative case.""" # Prepare - self.driver.push_messages.side_effect = self._get_push_messages( + self.driver.send_and_receive.side_effect = self._exec_send_and_receive( None, error_reply=True ) request_properties: Config = {"tensor_type": "str"} @@ -163,7 +161,7 @@ def test_get_properties_and_fail(self) -> None: def test_get_parameters_and_fail(self) -> None: """Test negative case.""" # Prepare - self.driver.push_messages.side_effect = self._get_push_messages( + self.driver.send_and_receive.side_effect = self._exec_send_and_receive( None, error_reply=True ) ins = GetParametersIns(config={}) @@ -177,7 +175,7 @@ def test_get_parameters_and_fail(self) -> None: def test_fit_and_fail(self) -> None: """Test negative case.""" # Prepare - self.driver.push_messages.side_effect = self._get_push_messages( + self.driver.send_and_receive.side_effect = self._exec_send_and_receive( None, error_reply=True ) parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) @@ -190,7 +188,7 @@ def test_fit_and_fail(self) -> None: def test_evaluate_and_fail(self) -> None: """Test negative case.""" # Prepare - self.driver.push_messages.side_effect = self._get_push_messages( + self.driver.send_and_receive.side_effect = self._exec_send_and_receive( None, error_reply=True ) parameters = Parameters(tensors=[b"random params%^&*F"], tensor_type="np") @@ -229,15 +227,15 @@ def _create_message_dummy( # pylint: disable=R0913 self.created_msg = Message(metadata=metadata, content=content) return self.created_msg - def _get_push_messages( + def _exec_send_and_receive( self, res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes, None], error_reply: bool = False, - ) -> Callable[[Iterable[Message]], Iterable[str]]: - """Get the push_messages function that sets the return value of pull_messages - when called.""" + ) -> Callable[[Iterable[Message]], Iterable[Message]]: + """Get the generate_replies function that sets the return value of driver's + send_and_receive when called.""" - def push_messages(messages: Iterable[Message]) -> Iterable[str]: + def generate_replies(messages: Iterable[Message]) -> Iterable[Message]: msg = list(messages)[0] if error_reply: recordset = None @@ -254,13 +252,11 @@ def push_messages(messages: Iterable[Message]) -> Iterable[str]: raise ValueError(f"Unsupported type: {type(res)}") if recordset is not None: ret = msg.create_reply(recordset) - ret.metadata.__dict__["_message_id"] = REPLY_MESSAGE_ID - # Set the return value of `pull_messages` - self.driver.pull_messages.return_value = [ret] - return [INSTRUCTION_MESSAGE_ID] + # Reply messages given the push message + return [ret] - return push_messages + return generate_replies def _common_assertions(self, original_ins: Any) -> None: """Check common assertions.""" @@ -275,18 +271,9 @@ def _common_assertions(self, original_ins: Any) -> None: self.assertEqual(self.called_times, 1) self.assertEqual(actual_ins, original_ins) - # Check if push_messages is called once with expected args/kwargs. - self.driver.push_messages.assert_called_once() + # Check if send_and_receive is called once with expected args/kwargs. + self.driver.send_and_receive.assert_called_once() try: - self.driver.push_messages.assert_any_call([self.created_msg]) + self.driver.send_and_receive.assert_any_call([self.created_msg]) except AssertionError: - self.driver.push_messages.assert_any_call(messages=[self.created_msg]) - - # Check if pull_messages is called once with expected args/kwargs. - self.driver.pull_messages.assert_called_once() - try: - self.driver.pull_messages.assert_called_with([INSTRUCTION_MESSAGE_ID]) - except AssertionError: - self.driver.pull_messages.assert_called_with( - message_ids=[INSTRUCTION_MESSAGE_ID] - ) + self.driver.send_and_receive.assert_any_call(messages=[self.created_msg])