Skip to content

Commit

Permalink
Define Error structure in common.Message. (#3034)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
Co-authored-by: Heng Pan <[email protected]>
  • Loading branch information
3 people authored Mar 1, 2024
1 parent f25ddc1 commit da7125a
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 52 deletions.
129 changes: 114 additions & 15 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,36 @@ def partition_id(self, value: int) -> None:
self._partition_id = value


@dataclass
class Error:
"""A dataclass that stores information about an error that ocurred.
Parameters
----------
code : int
An identifier for the error.
reason : Optional[str]
A reason for why the error arised (e.g. an exception stack-trace)
"""

_code: int
_reason: str | None = None

def __init__(self, code: int, reason: str | None = None) -> None:
self._code = code
self._reason = reason

@property
def code(self) -> int:
"""Error code."""
return self._code

@property
def reason(self) -> str | None:
"""Reason reported about the error."""
return self._reason


@dataclass
class Message:
"""State of your application from the viewpoint of the entity using it.
Expand All @@ -162,17 +192,31 @@ class Message:
----------
metadata : Metadata
A dataclass including information about the message to be executed.
content : RecordSet
content : Optional[RecordSet]
Holds records either sent by another entity (e.g. sent by the server-side
logic to a client, or vice-versa) or that will be sent to it.
error : Optional[Error]
A dataclass that captures information about an error that took place
when processing another message.
"""

_metadata: Metadata
_content: RecordSet
_content: RecordSet | None = None
_error: Error | None = None

def __init__(self, metadata: Metadata, content: RecordSet) -> None:
def __init__(
self,
metadata: Metadata,
content: RecordSet | None = None,
error: Error | None = None,
) -> None:
self._metadata = metadata

if not (content is None) ^ (error is None):
raise ValueError("Either `content` or `error` must be set, but not both.")

self._content = content
self._error = error

@property
def metadata(self) -> Metadata:
Expand All @@ -182,12 +226,77 @@ def metadata(self) -> Metadata:
@property
def content(self) -> RecordSet:
"""The content of this message."""
if self._content is None:
raise ValueError(
"Message content is None. Use <message>.has_content() "
"to check if a message has content."
)
return self._content

@content.setter
def content(self, value: RecordSet) -> None:
"""Set content."""
self._content = value
if self._error is None:
self._content = value
else:
raise ValueError("A message with an error set cannot have content.")

@property
def error(self) -> Error:
"""Error captured by this message."""
if self._error is None:
raise ValueError(
"Message error is None. Use <message>.has_error() "
"to check first if a message carries an error."
)
return self._error

@error.setter
def error(self, value: Error) -> None:
"""Set error."""
if self.has_content():
raise ValueError("A message with content set cannot carry an error.")
self._error = value

def has_content(self) -> bool:
"""Return True if message has content, else False."""
return self._content is not None

def has_error(self) -> bool:
"""Return True if message has an error, else False."""
return self._error is not None

def _create_reply_metadata(self, ttl: str) -> Metadata:
"""Construct metadata for a reply message."""
return Metadata(
run_id=self.metadata.run_id,
message_id="",
src_node_id=self.metadata.dst_node_id,
dst_node_id=self.metadata.src_node_id,
reply_to_message=self.metadata.message_id,
group_id=self.metadata.group_id,
ttl=ttl,
message_type=self.metadata.message_type,
partition_id=self.metadata.partition_id,
)

def create_error_reply(
self,
error: Error,
ttl: str,
) -> Message:
"""Construct a reply message indicating an error happened.
Parameters
----------
error : Error
The error that was encountered.
ttl : str
Time-to-live for this message.
"""
# Create reply with error
message = Message(metadata=self._create_reply_metadata(ttl), error=error)
return message

def create_reply(self, content: RecordSet, ttl: str) -> Message:
"""Create a reply to this message with specified content and TTL.
Expand All @@ -209,16 +318,6 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message:
A new `Message` instance representing the reply.
"""
return Message(
metadata=Metadata(
run_id=self.metadata.run_id,
message_id="",
src_node_id=self.metadata.dst_node_id,
dst_node_id=self.metadata.src_node_id,
reply_to_message=self.metadata.message_id,
group_id=self.metadata.group_id,
ttl=ttl,
message_type=self.metadata.message_type,
partition_id=self.metadata.partition_id,
),
metadata=self._create_reply_metadata(ttl),
content=content,
)
54 changes: 47 additions & 7 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from google.protobuf.message import Message as GrpcMessage

# pylint: disable=E0611
from flwr.proto.error_pb2 import Error as ProtoError
from flwr.proto.node_pb2 import Node
from flwr.proto.recordset_pb2 import Array as ProtoArray
from flwr.proto.recordset_pb2 import BoolList, BytesList
Expand All @@ -44,7 +45,7 @@

# pylint: enable=E0611
from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing
from .message import Message, Metadata
from .message import Error, Message, Metadata
from .record.typeddict import TypedDict

# === Parameters message ===
Expand Down Expand Up @@ -512,6 +513,21 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord
)


# === Error message ===


def error_to_proto(error: Error) -> ProtoError:
"""Serialize Error to ProtoBuf."""
reason = error.reason if error.reason else ""
return ProtoError(code=error.code, reason=reason)


def error_from_proto(error_proto: ProtoError) -> Error:
"""Deserialize Error from ProtoBuf."""
reason = error_proto.reason if len(error_proto.reason) > 0 else None
return Error(code=error_proto.code, reason=reason)


# === RecordSet message ===


Expand Down Expand Up @@ -562,7 +578,10 @@ def message_to_taskins(message: Message) -> TaskIns:
ttl=md.ttl,
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
recordset=recordset_to_proto(message.content),
recordset=(
recordset_to_proto(message.content) if message.has_content() else None
),
error=error_to_proto(message.error) if message.has_error() else None,
),
)

Expand All @@ -581,10 +600,19 @@ def message_from_taskins(taskins: TaskIns) -> Message:
message_type=taskins.task.task_type,
)

# Return the Message
# Construct Message
return Message(
metadata=metadata,
content=recordset_from_proto(taskins.task.recordset),
content=(
recordset_from_proto(taskins.task.recordset)
if taskins.task.HasField("recordset")
else None
),
error=(
error_from_proto(taskins.task.error)
if taskins.task.HasField("error")
else None
),
)


Expand All @@ -601,7 +629,10 @@ def message_to_taskres(message: Message) -> TaskRes:
ttl=md.ttl,
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
recordset=recordset_to_proto(message.content),
recordset=(
recordset_to_proto(message.content) if message.has_content() else None
),
error=error_to_proto(message.error) if message.has_error() else None,
),
)

Expand All @@ -620,8 +651,17 @@ def message_from_taskres(taskres: TaskRes) -> Message:
message_type=taskres.task.task_type,
)

# Return the Message
# Construct the Message
return Message(
metadata=metadata,
content=recordset_from_proto(taskres.task.recordset),
content=(
recordset_from_proto(taskres.task.recordset)
if taskres.task.HasField("recordset")
else None
),
error=(
error_from_proto(taskres.task.error)
if taskres.task.HasField("error")
else None
),
)
Loading

0 comments on commit da7125a

Please sign in to comment.