diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 1e1132e42e27..ecb1615a3441 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -154,6 +154,36 @@ def partition_id(self, value: int) -> None: self._partition_id = value +@dataclass +class Error: + """A dataclass that stores information about an error that occurred. + + Parameters + ---------- + code : int + An identifier for the error. + reason : Optional[str] + A reason for why the error arose (e.g. an exception stack-trace) + """ + + _code: int + _reason: str | None = None + + def __init__(self, code: int, reason: str | None = None) -> None: + self._code = code + self._reason = reason + + @property + def code(self) -> int: + """Error code.""" + return self._code + + @property + def reason(self) -> str | None: + """Reason reported about the error.""" + return self._reason + + @dataclass class Message: """State of your application from the viewpoint of the entity using it. @@ -162,17 +192,31 @@ class Message: ---------- metadata : Metadata A dataclass including information about the message to be executed. - content : RecordSet + content : Optional[RecordSet] Holds records either sent by another entity (e.g. sent by the server-side logic to a client, or vice-versa) or that will be sent to it. + error : Optional[Error] + A dataclass that captures information about an error that took place + when processing another message. 
""" _metadata: Metadata - _content: RecordSet + _content: RecordSet | None = None + _error: Error | None = None - def __init__(self, metadata: Metadata, content: RecordSet) -> None: + def __init__( + self, + metadata: Metadata, + content: RecordSet | None = None, + error: Error | None = None, + ) -> None: self._metadata = metadata + + if not (content is None) ^ (error is None): + raise ValueError("Either `content` or `error` must be set, but not both.") + self._content = content + self._error = error @property def metadata(self) -> Metadata: @@ -182,12 +226,77 @@ def metadata(self) -> Metadata: @property def content(self) -> RecordSet: """The content of this message.""" + if self._content is None: + raise ValueError( + "Message content is None. Use .has_content() " + "to check if a message has content." + ) return self._content @content.setter def content(self, value: RecordSet) -> None: """Set content.""" - self._content = value + if self._error is None: + self._content = value + else: + raise ValueError("A message with an error set cannot have content.") + + @property + def error(self) -> Error: + """Error captured by this message.""" + if self._error is None: + raise ValueError( + "Message error is None. Use .has_error() " + "to check first if a message carries an error." 
+ ) + return self._error + + @error.setter + def error(self, value: Error) -> None: + """Set error.""" + if self.has_content(): + raise ValueError("A message with content set cannot carry an error.") + self._error = value + + def has_content(self) -> bool: + """Return True if message has content, else False.""" + return self._content is not None + + def has_error(self) -> bool: + """Return True if message has an error, else False.""" + return self._error is not None + + def _create_reply_metadata(self, ttl: str) -> Metadata: + """Construct metadata for a reply message.""" + return Metadata( + run_id=self.metadata.run_id, + message_id="", + src_node_id=self.metadata.dst_node_id, + dst_node_id=self.metadata.src_node_id, + reply_to_message=self.metadata.message_id, + group_id=self.metadata.group_id, + ttl=ttl, + message_type=self.metadata.message_type, + partition_id=self.metadata.partition_id, + ) + + def create_error_reply( + self, + error: Error, + ttl: str, + ) -> Message: + """Construct a reply message indicating an error happened. + + Parameters + ---------- + error : Error + The error that was encountered. + ttl : str + Time-to-live for this message. + """ + # Create reply with error + message = Message(metadata=self._create_reply_metadata(ttl), error=error) + return message def create_reply(self, content: RecordSet, ttl: str) -> Message: """Create a reply to this message with specified content and TTL. @@ -209,16 +318,6 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message: A new `Message` instance representing the reply. 
""" return Message( - metadata=Metadata( - run_id=self.metadata.run_id, - message_id="", - src_node_id=self.metadata.dst_node_id, - dst_node_id=self.metadata.src_node_id, - reply_to_message=self.metadata.message_id, - group_id=self.metadata.group_id, - ttl=ttl, - message_type=self.metadata.message_type, - partition_id=self.metadata.partition_id, - ), + metadata=self._create_reply_metadata(ttl), content=content, ) diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 531a4bde6e9d..6c7a077d2f9f 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -20,6 +20,7 @@ from google.protobuf.message import Message as GrpcMessage # pylint: disable=E0611 +from flwr.proto.error_pb2 import Error as ProtoError from flwr.proto.node_pb2 import Node from flwr.proto.recordset_pb2 import Array as ProtoArray from flwr.proto.recordset_pb2 import BoolList, BytesList @@ -44,7 +45,7 @@ # pylint: enable=E0611 from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing -from .message import Message, Metadata +from .message import Error, Message, Metadata from .record.typeddict import TypedDict # === Parameters message === @@ -512,6 +513,21 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord ) +# === Error message === + + +def error_to_proto(error: Error) -> ProtoError: + """Serialize Error to ProtoBuf.""" + reason = error.reason if error.reason else "" + return ProtoError(code=error.code, reason=reason) + + +def error_from_proto(error_proto: ProtoError) -> Error: + """Deserialize Error from ProtoBuf.""" + reason = error_proto.reason if len(error_proto.reason) > 0 else None + return Error(code=error_proto.code, reason=reason) + + # === RecordSet message === @@ -562,7 +578,10 @@ def message_to_taskins(message: Message) -> TaskIns: ttl=md.ttl, ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], task_type=md.message_type, - recordset=recordset_to_proto(message.content), + 
recordset=( + recordset_to_proto(message.content) if message.has_content() else None + ), + error=error_to_proto(message.error) if message.has_error() else None, ), ) @@ -581,10 +600,19 @@ def message_from_taskins(taskins: TaskIns) -> Message: message_type=taskins.task.task_type, ) - # Return the Message + # Construct Message return Message( metadata=metadata, - content=recordset_from_proto(taskins.task.recordset), + content=( + recordset_from_proto(taskins.task.recordset) + if taskins.task.HasField("recordset") + else None + ), + error=( + error_from_proto(taskins.task.error) + if taskins.task.HasField("error") + else None + ), ) @@ -601,7 +629,10 @@ def message_to_taskres(message: Message) -> TaskRes: ttl=md.ttl, ancestry=[md.reply_to_message] if md.reply_to_message != "" else [], task_type=md.message_type, - recordset=recordset_to_proto(message.content), + recordset=( + recordset_to_proto(message.content) if message.has_content() else None + ), + error=error_to_proto(message.error) if message.has_error() else None, ), ) @@ -620,8 +651,17 @@ def message_from_taskres(taskres: TaskRes) -> Message: message_type=taskres.task.task_type, ) - # Return the Message + # Construct the Message return Message( metadata=metadata, - content=recordset_from_proto(taskres.task.recordset), + content=( + recordset_from_proto(taskres.task.recordset) + if taskres.task.HasField("recordset") + else None + ), + error=( + error_from_proto(taskres.task.error) + if taskres.task.HasField("error") + else None + ), ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 1f25fd1852c1..509ce2700948 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -14,10 +14,12 @@ # ============================================================================== """(De-)serialization tests.""" - import random import string -from typing import Any, Optional, OrderedDict, Type, TypeVar, Union, cast +from contextlib import ExitStack +from typing 
import Any, Callable, Optional, OrderedDict, Type, TypeVar, Union, cast + +import pytest # pylint: disable=E0611 from flwr.proto import transport_pb2 as pb2 @@ -29,7 +31,7 @@ # pylint: enable=E0611 from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing -from .message import Message, Metadata +from .message import Error, Message, Metadata from .serde import ( array_from_proto, array_to_proto, @@ -298,44 +300,115 @@ def test_recordset_serialization_deserialization() -> None: assert original == deserialized -def test_message_to_and_from_taskins() -> None: +@pytest.mark.parametrize( + "content_fn, error_fn, context", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + None, + ), # check when only content is set + (None, lambda code: Error(code=code), None), # check when only error is set + ( + lambda maker: maker.recordset(1, 1, 1), + lambda code: Error(code=code), + pytest.raises(ValueError), + ), # check when both are set (ERROR) + (None, None, pytest.raises(ValueError)), # check when neither is set (ERROR) + ], +) +def test_message_to_and_from_taskins( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], + context: Any, +) -> None: """Test Message to and from TaskIns.""" # Prepare + maker = RecordMaker(state=1) metadata = maker.metadata() # pylint: disable-next=protected-access metadata._src_node_id = 0 # Assume driver node - original = Message( - metadata=metadata, - content=maker.recordset(1, 1, 1), - ) - # Execute - taskins = message_to_taskins(original) - taskins.task_id = metadata.message_id - deserialized = message_from_taskins(taskins) + with ExitStack() as stack: + if context: + stack.enter_context(context) - # Assert - assert original.content == deserialized.content - assert metadata == deserialized.metadata + original = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + # 
Execute + taskins = message_to_taskins(original) + taskins.task_id = metadata.message_id + deserialized = message_from_taskins(taskins) -def test_message_to_and_from_taskres() -> None: + # Assert + if original.has_content(): + assert original.content == deserialized.content + if original.has_error(): + assert original.error == deserialized.error + assert metadata == deserialized.metadata + + +@pytest.mark.parametrize( + "content_fn, error_fn, context", + [ + ( + lambda maker: maker.recordset(1, 1, 1), + None, + None, + ), # check when only content is set + (None, lambda code: Error(code=code), None), # check when only error is set + ( + lambda maker: maker.recordset(1, 1, 1), + lambda code: Error(code=code), + pytest.raises(ValueError), + ), # check when both are set (ERROR) + (None, None, pytest.raises(ValueError)), # check when neither is set (ERROR) + ], +) +def test_message_to_and_from_taskres( + content_fn: Callable[ + [ + RecordMaker, + ], + RecordSet, + ], + error_fn: Callable[[int], Error], + context: Any, +) -> None: """Test Message to and from TaskRes.""" # Prepare maker = RecordMaker(state=2) metadata = maker.metadata() metadata.dst_node_id = 0 # Assume driver node - original = Message( - metadata=metadata, - content=maker.recordset(1, 1, 1), - ) - # Execute - taskres = message_to_taskres(original) - taskres.task_id = metadata.message_id - deserialized = message_from_taskres(taskres) + with ExitStack() as stack: + if context: + stack.enter_context(context) - # Assert - assert original.content == deserialized.content - assert metadata == deserialized.metadata + original = Message( + metadata=metadata, + content=None if content_fn is None else content_fn(maker), + error=None if error_fn is None else error_fn(0), + ) + + # Execute + taskres = message_to_taskres(original) + taskres.task_id = metadata.message_id + deserialized = message_from_taskres(taskres) + + # Assert + if original.has_content(): + assert original.content == deserialized.content + if 
original.has_error(): + assert original.error == deserialized.error + assert metadata == deserialized.metadata diff --git a/src/py/flwr/server/driver/driver_test.py b/src/py/flwr/server/driver/driver_test.py index bd5c23a407fd..2bf253222f94 100644 --- a/src/py/flwr/server/driver/driver_test.py +++ b/src/py/flwr/server/driver/driver_test.py @@ -20,6 +20,8 @@ from unittest.mock import Mock, patch from flwr.common import RecordSet +from flwr.common.message import Error +from flwr.common.serde import error_to_proto, recordset_to_proto from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 GetNodesRequest, PullTaskResRequest, @@ -132,9 +134,13 @@ def test_pull_messages_with_given_message_ids(self) -> None: """Test pulling messages with specific message IDs.""" # Prepare mock_response = Mock() + # A Message must have either content or error set so we prepare + # two tasks that contain these. mock_response.task_res_list = [ - TaskRes(task=Task(ancestry=["id2"])), - TaskRes(task=Task(ancestry=["id3"])), + TaskRes( + task=Task(ancestry=["id2"], recordset=recordset_to_proto(RecordSet())) + ), + TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), ] self.mock_grpc_driver.pull_task_res.return_value = mock_response msg_ids = ["id1", "id2", "id3"] @@ -157,7 +163,12 @@ def test_send_and_receive_messages_complete(self) -> None: # Prepare mock_response = Mock(task_ids=["id1"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response - mock_response = Mock(task_res_list=[TaskRes(task=Task(ancestry=["id1"]))]) + # The response message must include either `content` (i.e. a recordset) or + # an `Error`. We choose the latter in this case + error_proto = error_to_proto(Error(code=0)) + mock_response = Mock( + task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] + ) self.mock_grpc_driver.pull_task_res.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")]