Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define Error structure in common.Message. #3034

Merged
merged 27 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ae3fd34
init
jafermarq Feb 28, 2024
62bc05a
w/ previous
jafermarq Feb 28, 2024
2f4d839
init
jafermarq Feb 28, 2024
daa23f1
init Error dataclass def
jafermarq Feb 28, 2024
2df4dea
updated protos
jafermarq Feb 28, 2024
2c5eae8
Merge branch 'proto-error' into error-message
jafermarq Feb 28, 2024
29c95c2
basic flow
jafermarq Feb 28, 2024
eb2e313
adding protobuf files
jafermarq Feb 28, 2024
bd595bb
added proto files
jafermarq Feb 28, 2024
08c76f6
Merge branch 'proto-error' into error-message
jafermarq Feb 29, 2024
e29fe86
Merge branch 'main' into error-message
danieljanes Feb 29, 2024
e03d96c
mutex `content` and `error`
jafermarq Feb 29, 2024
fae2c8a
mutex content and error; and fixes
jafermarq Feb 29, 2024
87207c6
w/ previous
jafermarq Feb 29, 2024
370ce25
post review updates
jafermarq Feb 29, 2024
79cb7f2
fixing existing tests
jafermarq Feb 29, 2024
5758f8f
back
jafermarq Feb 29, 2024
e51eee1
updates to tests
jafermarq Feb 29, 2024
40a17f9
Merge branch 'main' into error-message
danieljanes Feb 29, 2024
145853c
Update src/py/flwr/common/message.py
danieljanes Feb 29, 2024
d1b9113
Update src/py/flwr/common/message.py
danieljanes Feb 29, 2024
093587e
Apply suggestions from code review
danieljanes Feb 29, 2024
140a7d0
Update src/py/flwr/server/driver/driver_test.py
danieljanes Feb 29, 2024
f77df39
Update src/py/flwr/common/message.py
danieljanes Feb 29, 2024
91c19ce
Merge branch 'main' into error-message
danieljanes Feb 29, 2024
708d4ce
update based on comments
jafermarq Feb 29, 2024
62c1dce
reverting after fixes
jafermarq Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
reason = cast(int, disconnect_msg.disconnect_res.reason)
recordset = RecordSet()
recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
out_message = message.create_reply(recordset, ttl="")
out_message = message.create_reply(content=recordset, ttl="")
# Return TaskRes and sleep duration
return out_message, sleep_duration

Expand Down Expand Up @@ -148,7 +148,7 @@ def handle_legacy_message_from_msgtype(
raise ValueError(f"Invalid message type: {message_type}")

# Return Message
return message.create_reply(out_recordset, ttl="")
return message.create_reply(content=out_recordset, ttl="")


def _reconnect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def secaggplus_mod(

# Return message
content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)})
return msg.create_reply(content, ttl="")
return msg.create_reply(content=content, ttl="")


def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_test_handler(
"""."""

def empty_ffn(_msg: Message, _2: Context) -> Message:
return _msg.create_reply(RecordSet(), ttl="")
return _msg.create_reply(content=RecordSet(), ttl="")

app = make_ffn(empty_ffn, [secaggplus_mod])

Expand Down
111 changes: 104 additions & 7 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,36 @@ def partition_id(self, value: int) -> None:
self._partition_id = value


@dataclass
class Error:
"""A dataclass that stores information about an error that ocurred.

Parameters
----------
code : int
An identifier for the error.
reasong : Optional[str]
A reason for why the error arised (e.g. an exception stack-trace)
"""

_code: int
_reason: str | None = None

def __init__(self, code: int, reason: str | None = None) -> None:
self._code = code
self._reason = reason

@property
def code(self) -> int:
"""Eerror code."""
return self._code

@property
def reason(self) -> str | None:
"""Reason reported about the error."""
return self._reason


@dataclass
class Message:
"""State of your application from the viewpoint of the entity using it.
Expand All @@ -162,17 +192,31 @@ class Message:
----------
metadata : Metadata
A dataclass including information about the message to be executed.
content : RecordSet
content : Optional[RecordSet]
Holds records either sent by another entity (e.g. sent by the server-side
logic to a client, or vice-versa) or that will be sent to it.
error : Optional[Error]
A dataclass that captures information about an error that took place
when processing another message.
"""

_metadata: Metadata
_content: RecordSet
_content: RecordSet | None = None
_error: Error | None = None

def __init__(self, metadata: Metadata, content: RecordSet) -> None:
def __init__(
self,
metadata: Metadata,
content: RecordSet | None = None,
error: Error | None = None,
) -> None:
self._metadata = metadata

if not (content is None) ^ (error is None):
raise ValueError("Either `content` or `error` must be set, but not both.")

self._content = content
self._error = error

@property
def metadata(self) -> Metadata:
Expand All @@ -182,14 +226,67 @@ def metadata(self) -> Metadata:
@property
def content(self) -> RecordSet:
"""The content of this message."""
if self._content is None:
raise ValueError(
"Message content is None. Use <message>.has_content() "
"if you'd like to check first if a message has content."
)
return self._content

@content.setter
def content(self, value: RecordSet) -> None:
"""Set content."""
self._content = value
if self._error is None:
self._content = value
else:
raise ValueError("A message with an error set cannot have content.")

@property
def error(self) -> Error:
"""Error captured by this message."""
if self._error is None:
raise ValueError(
"Message error is None. Use <message>.has_error() "
"if you'd like to check first if a message carries an error."
)
return self._error

@error.setter
def error(self, value: Error) -> None:
"""Set error."""
if self.has_content():
raise ValueError("A message with content set cannot carry an error.")
self._error = value

def has_content(self) -> bool:
"""Return True if message has content, else False."""
return self._content is not None

def has_error(self) -> bool:
"""Return True if message has an error, else False."""
return self._error is not None

def create_error(
self,
error: Error,
ttl: str,
) -> Message:
"""Construct valid response message indicating an error happened.

Parameters
----------
error : Error
The error that was encountered.
ttl : str
Time-to-live for this message.
"""
# Create reply without content
message = self.create_reply(ttl=ttl)
# Set error
message.error = error
return message

def create_reply(self, content: RecordSet, ttl: str) -> Message:
def create_reply(self, ttl: str, content: RecordSet | None = None) -> Message:
"""Create a reply to this message with specified content and TTL.

The method generates a new `Message` as a reply to this message.
Expand All @@ -198,10 +295,10 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message:

Parameters
----------
content : RecordSet
The content for the reply message.
ttl : str
Time-to-live for this message.
content : Optional[RecordSet]
The content for the reply message.

Returns
-------
Expand Down
54 changes: 47 additions & 7 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from google.protobuf.message import Message as GrpcMessage

# pylint: disable=E0611
from flwr.proto.error_pb2 import Error as ProtoError
from flwr.proto.node_pb2 import Node
from flwr.proto.recordset_pb2 import Array as ProtoArray
from flwr.proto.recordset_pb2 import BoolList, BytesList
Expand All @@ -44,7 +45,7 @@

# pylint: enable=E0611
from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing
from .message import Message, Metadata
from .message import Error, Message, Metadata
from .record.typeddict import TypedDict

# === Parameters message ===
Expand Down Expand Up @@ -512,6 +513,21 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord
)


# === Error message ===


def error_to_proto(error: Error) -> ProtoError:
"""Serialize Error to ProtoBuf."""
reason = error.reason if error.reason else ""
return ProtoError(code=error.code, reason=reason)


def error_from_proto(error_proto: ProtoError) -> Error:
"""Deserialize Error from ProtoBuf."""
reason = error_proto.reason if len(error_proto.reason) > 0 else None
return Error(code=error_proto.code, reason=reason)


# === RecordSet message ===


Expand Down Expand Up @@ -562,7 +578,10 @@ def message_to_taskins(message: Message) -> TaskIns:
ttl=md.ttl,
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
recordset=recordset_to_proto(message.content),
recordset=(
recordset_to_proto(message.content) if message.has_content() else None
),
error=error_to_proto(message.error) if message.has_error() else None,
),
)

Expand All @@ -581,10 +600,19 @@ def message_from_taskins(taskins: TaskIns) -> Message:
message_type=taskins.task.task_type,
)

# Return the Message
# Construct Message
return Message(
metadata=metadata,
content=recordset_from_proto(taskins.task.recordset),
content=(
recordset_from_proto(taskins.task.recordset)
if taskins.task.HasField("recordset")
else None
),
error=(
error_from_proto(taskins.task.error)
if taskins.task.HasField("error")
else None
),
)


Expand All @@ -601,7 +629,10 @@ def message_to_taskres(message: Message) -> TaskRes:
ttl=md.ttl,
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
recordset=recordset_to_proto(message.content),
recordset=(
recordset_to_proto(message.content) if message.has_content() else None
),
error=error_to_proto(message.error) if message.has_error() else None,
),
)

Expand All @@ -620,8 +651,17 @@ def message_from_taskres(taskres: TaskRes) -> Message:
message_type=taskres.task.task_type,
)

# Return the Message
# Construct the Message
return Message(
metadata=metadata,
content=recordset_from_proto(taskres.task.recordset),
content=(
recordset_from_proto(taskres.task.recordset)
if taskres.task.HasField("recordset")
else None
),
error=(
error_from_proto(taskres.task.error)
if taskres.task.HasField("error")
else None
),
)
Loading