Skip to content

Commit

Permalink
update taskins taskres validation
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 27, 2024
1 parent d25ebb2 commit ecee3ca
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 162 deletions.
4 changes: 1 addition & 3 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def receive() -> Optional[TaskIns]:
task_ins: Optional[TaskIns] = get_task_ins(response)

# Discard the current TaskIns if not valid
if task_ins is not None and not validate_task_ins(
task_ins, discard_reconnect_ins=True
):
if task_ins is not None and not validate_task_ins(task_ins):
task_ins = None

# Remember `task_ins` until `task_res` is available
Expand Down
51 changes: 2 additions & 49 deletions src/py/flwr/client/message_handler/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,45 +20,24 @@
from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
ServerMessage,
)


def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
def validate_task_ins(task_ins: TaskIns) -> bool:
"""Validate a TaskIns before it entering the message handling process.
Parameters
----------
task_ins: TaskIns
The task instruction coming from the server.
discard_reconnect_ins: bool
If True, ReconnectIns will not be considered as valid content.
Returns
-------
is_valid: bool
True if the TaskIns is deemed valid and therefore suitable for
undergoing the message handling process, False otherwise.
"""
# Check if the task_ins contains legacy_server_message or sa.
# If legacy_server_message is set, check if ServerMessage is one of
# {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns, ReconnectIns*}
# Discard ReconnectIns if discard_reconnect_ins is true.
if (
not task_ins.HasField("task")
or (
not task_ins.task.HasField("legacy_server_message")
and not task_ins.task.HasField("sa")
)
or (
discard_reconnect_ins
and task_ins.task.legacy_server_message.WhichOneof("msg") == "reconnect_ins"
)
):
if not (task_ins.HasField("task") and task_ins.task.HasField("recordset")):
return False

return True


Expand Down Expand Up @@ -110,32 +89,6 @@ def get_task_ins(
return task_ins


def get_server_message_from_task_ins(
task_ins: TaskIns, exclude_reconnect_ins: bool
) -> Optional[ServerMessage]:
"""Get ServerMessage from TaskIns, if available."""
# Return the message if it is in
# {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns}
# Return the message if it is ReconnectIns and exclude_reconnect_ins is False.
if not validate_task_ins(
task_ins, discard_reconnect_ins=exclude_reconnect_ins
) or not task_ins.task.HasField("legacy_server_message"):
return None

return task_ins.task.legacy_server_message


def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes:
"""Wrap ClientMessage in TaskRes."""
# Instantiate a TaskRes, only filling client_message field.
return TaskRes(
task_id="",
group_id="",
run_id=0,
task=Task(ancestry=[], legacy_client_message=client_message),
)


def configure_task_res(
task_res: TaskRes, ref_task_ins: TaskIns, producer: Node
) -> TaskRes:
Expand Down
116 changes: 9 additions & 107 deletions src/py/flwr/client/message_handler/task_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,75 +16,35 @@


from flwr.client.message_handler.task_handler import (
get_server_message_from_task_ins,
get_task_ins,
validate_task_ins,
validate_task_res,
wrap_client_message_in_task_res,
)
from flwr.common import serde
from flwr.common.recordset import RecordSet
from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
SecureAggregation,
Task,
TaskIns,
TaskRes,
)
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
ServerMessage,
)
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611


def test_validate_task_ins_no_task() -> None:
"""Test validate_task_ins."""
task_ins = TaskIns(task=None)

assert not validate_task_ins(task_ins, discard_reconnect_ins=True)
assert not validate_task_ins(task_ins, discard_reconnect_ins=False)
assert not validate_task_ins(task_ins)


def test_validate_task_ins_no_content() -> None:
"""Test validate_task_ins."""
task_ins = TaskIns(task=Task(legacy_server_message=None, sa=None))

assert not validate_task_ins(task_ins, discard_reconnect_ins=True)
assert not validate_task_ins(task_ins, discard_reconnect_ins=False)


def test_validate_task_ins_with_reconnect_ins() -> None:
"""Test validate_task_ins."""
task_ins = TaskIns(
task=Task(
legacy_server_message=ServerMessage(
reconnect_ins=ServerMessage.ReconnectIns(seconds=3)
)
)
)

assert not validate_task_ins(task_ins, discard_reconnect_ins=True)
assert validate_task_ins(task_ins, discard_reconnect_ins=False)


def test_validate_task_ins_valid_legacy_server_message() -> None:
"""Test validate_task_ins."""
task_ins = TaskIns(
task=Task(
legacy_server_message=ServerMessage(
get_properties_ins=ServerMessage.GetPropertiesIns()
)
)
)
task_ins = TaskIns(task=Task(recordset=None))

assert validate_task_ins(task_ins, discard_reconnect_ins=True)
assert validate_task_ins(task_ins, discard_reconnect_ins=False)
assert not validate_task_ins(task_ins)


def test_validate_task_ins_valid_sa() -> None:
def test_validate_task_ins_valid() -> None:
"""Test validate_task_ins."""
task_ins = TaskIns(task=Task(sa=SecureAggregation()))
task_ins = TaskIns(task=Task(recordset=serde.recordset_to_proto(RecordSet())))

assert validate_task_ins(task_ins, discard_reconnect_ins=True)
assert validate_task_ins(task_ins, discard_reconnect_ins=False)
assert validate_task_ins(task_ins)


def test_validate_task_res() -> None:
Expand Down Expand Up @@ -142,61 +102,3 @@ def test_get_task_ins_multiple_ins() -> None:
)
actual_task_ins = get_task_ins(res)
assert actual_task_ins == expected_task_ins


def test_get_server_message_from_task_ins_invalid() -> None:
"""Test get_server_message_from_task_ins."""
task_ins = TaskIns(task=Task(legacy_server_message=None))
msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True)
msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)

assert msg_t is None
assert msg_f is None


def test_get_server_message_from_task_ins_reconnect_ins() -> None:
"""Test get_server_message_from_task_ins."""
expected_server_message = ServerMessage(
reconnect_ins=ServerMessage.ReconnectIns(seconds=3)
)
task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message))
msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True)
msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)

assert msg_t is None
assert msg_f == expected_server_message


def test_get_server_message_from_task_ins_sa() -> None:
"""Test get_server_message_from_task_ins."""
task_ins = TaskIns(task=Task(sa=SecureAggregation()))
msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True)
msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)

assert msg_t is None
assert msg_f is None


def test_get_server_message_from_task_ins_valid_legacy_server_message() -> None:
"""Test get_server_message_from_task_ins."""
expected_server_message = ServerMessage(
get_properties_ins=ServerMessage.GetPropertiesIns()
)
task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message))
msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True)
msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)

assert msg_t == expected_server_message
assert msg_f == expected_server_message


def test_wrap_client_message_in_task_res() -> None:
"""Test wrap_client_message_in_task_res."""
expected_client_message = ClientMessage(
get_properties_res=ClientMessage.GetPropertiesRes()
)
task_res = wrap_client_message_in_task_res(expected_client_message)

assert validate_task_res(task_res)
# pylint: disable-next=no-member
assert task_res.task.legacy_client_message == expected_client_message
4 changes: 1 addition & 3 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,7 @@ def receive() -> Optional[TaskIns]:
task_ins: Optional[TaskIns] = get_task_ins(pull_task_ins_response_proto)

# Discard the current TaskIns if not valid
if task_ins is not None and not validate_task_ins(
task_ins, discard_reconnect_ins=True
):
if task_ins is not None and not validate_task_ins(task_ins):
task_ins = None

# Remember `task_ins` until `task_res` is available
Expand Down

0 comments on commit ecee3ca

Please sign in to comment.