Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Dec 22, 2023
1 parent 58fb02f commit 8dea6c5
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
4 changes: 3 additions & 1 deletion src/py/flwr/driver/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def test_get_properties(self) -> None:
)

# Execute
value: flwr.common.GetPropertiesRes = client.get_properties(ins, timeout=None)
value: flwr.common.GetPropertiesRes = client.get_properties(
ins, timeout=None, group_id=None
)

# Assert
assert value.properties["tensor_type"] == "numpy.ndarray"
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _call_client_proxy(
get_properties_res: GetPropertiesRes = client_proxy.get_properties(
ins=get_properties_ins,
timeout=timeout,
group_id=None,
)
get_properties_res_proto = serde.get_properties_res_to_proto(
res=get_properties_res
Expand All @@ -142,6 +143,7 @@ def _call_client_proxy(
get_parameters_res: GetParametersRes = client_proxy.get_parameters(
ins=get_parameters_ins,
timeout=timeout,
group_id=None,
)
get_parameters_res_proto = serde.get_parameters_res_to_proto(
res=get_parameters_res
Expand Down
16 changes: 10 additions & 6 deletions src/py/flwr/server/server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ class SuccessClient(ClientProxy):
"""Test class."""

def get_properties(
self, ins: GetPropertiesIns, timeout: Optional[float]
self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[str]
) -> GetPropertiesRes:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()

def get_parameters(
self, ins: GetParametersIns, timeout: Optional[float]
self, ins: GetParametersIns, timeout: Optional[float], group_id: Optional[str]
) -> GetParametersRes:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
Expand Down Expand Up @@ -80,7 +80,9 @@ def evaluate(
metrics={},
)

def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes:
def reconnect(
self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[str]
) -> DisconnectRes:
"""Simulate reconnect by returning a DisconnectRes with UNKNOWN reason."""
return DisconnectRes(reason="UNKNOWN")

Expand All @@ -89,13 +91,13 @@ class FailingClient(ClientProxy):
"""Test class."""

def get_properties(
self, ins: GetPropertiesIns, timeout: Optional[float]
self, ins: GetPropertiesIns, timeout: Optional[float], group_id: Optional[str]
) -> GetPropertiesRes:
"""Raise an Exception to simulate failure in the client."""
raise Exception()

def get_parameters(
self, ins: GetParametersIns, timeout: Optional[float]
self, ins: GetParametersIns, timeout: Optional[float], group_id: Optional[str]
) -> GetParametersRes:
"""Raise an Exception to simulate failure in the client."""
raise Exception()
Expand All @@ -112,7 +114,9 @@ def evaluate(
"""Raise an Exception to simulate failure in the client."""
raise Exception()

def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes:
def reconnect(
self, ins: ReconnectIns, timeout: Optional[float], group_id: Optional[str]
) -> DisconnectRes:
"""Raise an Exception to simulate failure in the client."""
raise Exception()

Expand Down

0 comments on commit 8dea6c5

Please sign in to comment.