Skip to content

Commit

Permalink
refactor(framework) Make DriverClientProxy use driver's `send_and_r…
Browse files Browse the repository at this point in the history
…eceive` (#4289)
  • Loading branch information
jafermarq authored Oct 8, 2024
1 parent ef07d49 commit 850ccce
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 62 deletions.
44 changes: 15 additions & 29 deletions src/py/flwr/server/compat/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Flower ClientProxy implementation for Driver API."""


import time
from typing import Optional

from flwr import common
Expand All @@ -25,8 +24,6 @@

from ..driver.driver import Driver

SLEEP_TIME = 1


class DriverClientProxy(ClientProxy):
"""Flower client proxy which delegates work using the Driver API."""
Expand Down Expand Up @@ -122,29 +119,18 @@ def _send_receive_recordset(
ttl=timeout,
)

# Push message
message_ids = list(self.driver.push_messages(messages=[message]))
if len(message_ids) != 1:
raise ValueError("Unexpected number of message_ids")

message_id = message_ids[0]
if message_id == "":
raise ValueError(f"Failed to send message to node {self.node_id}")

if timeout:
start_time = time.time()

while True:
messages = list(self.driver.pull_messages(message_ids))
if len(messages) == 1:
msg: Message = messages[0]
if msg.has_error():
raise ValueError(
f"Message contains an Error (reason: {msg.error.reason}). "
"It originated during client-side execution of a message."
)
return msg.content

if timeout is not None and time.time() > start_time + timeout:
raise RuntimeError("Timeout reached")
time.sleep(SLEEP_TIME)
# Send message and wait for reply
messages = list(self.driver.send_and_receive(messages=[message]))

# A single reply is expected
if len(messages) != 1:
raise ValueError(f"Expected one Message but got: {len(messages)}")

# Only messages without errors can be handled beyond these point
msg: Message = messages[0]
if msg.has_error():
raise ValueError(
f"Message contains an Error (reason: {msg.error.reason}). "
"It originated during client-side execution of a message."
)
return msg.content
53 changes: 20 additions & 33 deletions src/py/flwr/server/compat/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@

RUN_ID = 61016
NODE_ID = 1
INSTRUCTION_MESSAGE_ID = "mock instruction message id"
REPLY_MESSAGE_ID = "mock reply message id"


class DriverClientProxyTestCase(unittest.TestCase):
Expand All @@ -77,7 +75,7 @@ def test_get_properties(self) -> None:
"""Test positive case."""
# Prepare
res = GetPropertiesRes(status=CLIENT_STATUS, properties=CLIENT_PROPERTIES)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
request_properties: Config = {"tensor_type": "str"}
ins = GetPropertiesIns(config=request_properties)

Expand All @@ -95,7 +93,7 @@ def test_get_parameters(self) -> None:
status=CLIENT_STATUS,
parameters=MESSAGE_PARAMETERS,
)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
ins = GetParametersIns(config={})

# Execute
Expand All @@ -114,7 +112,7 @@ def test_fit(self) -> None:
num_examples=10,
metrics={},
)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
ins = FitIns(parameters, {})

Expand All @@ -134,7 +132,7 @@ def test_evaluate(self) -> None:
num_examples=0,
metrics={},
)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
parameters = Parameters(tensors=[b"random params%^&*F"], tensor_type="np")
ins = EvaluateIns(parameters, {})

Expand All @@ -148,7 +146,7 @@ def test_evaluate(self) -> None:
def test_get_properties_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
request_properties: Config = {"tensor_type": "str"}
Expand All @@ -163,7 +161,7 @@ def test_get_properties_and_fail(self) -> None:
def test_get_parameters_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
ins = GetParametersIns(config={})
Expand All @@ -177,7 +175,7 @@ def test_get_parameters_and_fail(self) -> None:
def test_fit_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
Expand All @@ -190,7 +188,7 @@ def test_fit_and_fail(self) -> None:
def test_evaluate_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
parameters = Parameters(tensors=[b"random params%^&*F"], tensor_type="np")
Expand Down Expand Up @@ -229,15 +227,15 @@ def _create_message_dummy( # pylint: disable=R0913
self.created_msg = Message(metadata=metadata, content=content)
return self.created_msg

def _get_push_messages(
def _exec_send_and_receive(
self,
res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes, None],
error_reply: bool = False,
) -> Callable[[Iterable[Message]], Iterable[str]]:
"""Get the push_messages function that sets the return value of pull_messages
when called."""
) -> Callable[[Iterable[Message]], Iterable[Message]]:
"""Get the generate_replies function that sets the return value of driver's
send_and_receive when called."""

def push_messages(messages: Iterable[Message]) -> Iterable[str]:
def generate_replies(messages: Iterable[Message]) -> Iterable[Message]:
msg = list(messages)[0]
if error_reply:
recordset = None
Expand All @@ -254,13 +252,11 @@ def push_messages(messages: Iterable[Message]) -> Iterable[str]:
raise ValueError(f"Unsupported type: {type(res)}")
if recordset is not None:
ret = msg.create_reply(recordset)
ret.metadata.__dict__["_message_id"] = REPLY_MESSAGE_ID

# Set the return value of `pull_messages`
self.driver.pull_messages.return_value = [ret]
return [INSTRUCTION_MESSAGE_ID]
# Reply messages given the push message
return [ret]

return push_messages
return generate_replies

def _common_assertions(self, original_ins: Any) -> None:
"""Check common assertions."""
Expand All @@ -275,18 +271,9 @@ def _common_assertions(self, original_ins: Any) -> None:
self.assertEqual(self.called_times, 1)
self.assertEqual(actual_ins, original_ins)

# Check if push_messages is called once with expected args/kwargs.
self.driver.push_messages.assert_called_once()
# Check if send_and_receive is called once with expected args/kwargs.
self.driver.send_and_receive.assert_called_once()
try:
self.driver.push_messages.assert_any_call([self.created_msg])
self.driver.send_and_receive.assert_any_call([self.created_msg])
except AssertionError:
self.driver.push_messages.assert_any_call(messages=[self.created_msg])

# Check if pull_messages is called once with expected args/kwargs.
self.driver.pull_messages.assert_called_once()
try:
self.driver.pull_messages.assert_called_with([INSTRUCTION_MESSAGE_ID])
except AssertionError:
self.driver.pull_messages.assert_called_with(
message_ids=[INSTRUCTION_MESSAGE_ID]
)
self.driver.send_and_receive.assert_any_call(messages=[self.created_msg])

0 comments on commit 850ccce

Please sign in to comment.