Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define Error structure in common.Message. #3034

Merged
merged 27 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ae3fd34
init
jafermarq Feb 28, 2024
62bc05a
w/ previous
jafermarq Feb 28, 2024
2f4d839
init
jafermarq Feb 28, 2024
daa23f1
init Error dataclass def
jafermarq Feb 28, 2024
2df4dea
updated protos
jafermarq Feb 28, 2024
2c5eae8
Merge branch 'proto-error' into error-message
jafermarq Feb 28, 2024
29c95c2
basic flow
jafermarq Feb 28, 2024
eb2e313
adding protobuf files
jafermarq Feb 28, 2024
bd595bb
added proto files
jafermarq Feb 28, 2024
08c76f6
Merge branch 'proto-error' into error-message
jafermarq Feb 29, 2024
e29fe86
Merge branch 'main' into error-message
danieljanes Feb 29, 2024
e03d96c
mutex `content` and `error`
jafermarq Feb 29, 2024
fae2c8a
mutex content and error; and fixes
jafermarq Feb 29, 2024
87207c6
w/ previous
jafermarq Feb 29, 2024
370ce25
post review updates
jafermarq Feb 29, 2024
79cb7f2
fixing existing tests
jafermarq Feb 29, 2024
5758f8f
back
jafermarq Feb 29, 2024
e51eee1
updates to tests
jafermarq Feb 29, 2024
40a17f9
Merge branch 'main' into error-message
danieljanes Feb 29, 2024
145853c
Update src/py/flwr/common/message.py
danieljanes Feb 29, 2024
d1b9113
Update src/py/flwr/common/message.py
danieljanes Feb 29, 2024
093587e
Apply suggestions from code review
danieljanes Feb 29, 2024
140a7d0
Update src/py/flwr/server/driver/driver_test.py
danieljanes Feb 29, 2024
f77df39
Update src/py/flwr/common/message.py
danieljanes Feb 29, 2024
91c19ce
Merge branch 'main' into error-message
danieljanes Feb 29, 2024
708d4ce
update based on comments
jafermarq Feb 29, 2024
62c1dce
reverting after fixes
jafermarq Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,36 @@ def partition_id(self, value: int) -> None:
self._partition_id = value


@dataclass
class Error:
"""A dataclass that stores information about an error that ocurred.

Parameters
----------
code : int
An identifier for the error.
reasong : Optional[str]
A reason for why the error arised (e.g. an exception stack-trace)
"""

_code: int
_reason: str | None = None

def __init__(self, code: int, reason: str | None = None) -> None:
self._code = code
self._reason = reason

@property
def code(self) -> int:
"""Eerror code."""
return self._code

@property
def reason(self) -> str | None:
"""Reason reported about the error."""
return self._reason


@dataclass
class Message:
"""State of your application from the viewpoint of the entity using it.
Expand All @@ -169,10 +199,12 @@ class Message:

_metadata: Metadata
_content: RecordSet
_error: Error

def __init__(self, metadata: Metadata, content: RecordSet) -> None:
self._metadata = metadata
self._content = content
self._error = Error(code=0) # TODO: decide about codes, init with default "NO_ERROR" code?

@property
def metadata(self) -> Metadata:
Expand All @@ -189,6 +221,42 @@ def content(self, value: RecordSet) -> None:
"""Set content."""
self._content = value

@property
def error(self) -> Error:
"""Error captured by this message."""
return self._error

@error.setter
def error(self, value: Error) -> None:
"""Set error."""
self._error = value

def construct_error_message(
self,
error_code: int,
ttl: str,
error_reason: str | None = None,
content: RecordSet | None = None,
) -> Message:
"""Construct valid response message indicating an error happened.

Parameters
----------
error_code : int
Error code.
ttl : str
Time-to-live for this message.
error_reason : Optional[str]
A reason for why the error arised (e.g. an exception stack-trace)
content : Optional[RecordSet]
The content for the reply message.
"""
message_content = content if content else RecordSet()
message = self.create_reply(content=message_content, ttl=ttl)
# Set error
message.error = Error(code=error_code, reason=error_reason)
return message

def create_reply(self, content: RecordSet, ttl: str) -> Message:
"""Create a reply to this message with specified content and TTL.

Expand Down
36 changes: 31 additions & 5 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from google.protobuf.message import Message as GrpcMessage

# pylint: disable=E0611
from flwr.proto.error_pb2 import Error as ProtoError
from flwr.proto.node_pb2 import Node
from flwr.proto.recordset_pb2 import Array as ProtoArray
from flwr.proto.recordset_pb2 import BoolList, BytesList
Expand All @@ -44,7 +45,7 @@

# pylint: enable=E0611
from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing
from .message import Message, Metadata
from .message import Error, Message, Metadata
from .record.typeddict import TypedDict

# === Parameters message ===
Expand Down Expand Up @@ -512,6 +513,21 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord
)


# === Error message ===


def error_to_proto(error: Error) -> ProtoError:
"""Serialize Error to ProtoBuf."""
reason = error.reason if error.reason else ""
return ProtoError(code=error.code, reason=reason)


def error_from_proto(error_proto: ProtoError) -> Error:
"""Deserialize Error from ProtoBuf."""
reason = error_proto.reason if len(error_proto.reason) > 0 else None
return Error(code=error_proto.code, reason=reason)


# === RecordSet message ===


Expand Down Expand Up @@ -563,6 +579,7 @@ def message_to_taskins(message: Message) -> TaskIns:
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
recordset=recordset_to_proto(message.content),
error=error_to_proto(message.error),
),
)

Expand All @@ -581,12 +598,16 @@ def message_from_taskins(taskins: TaskIns) -> Message:
message_type=taskins.task.task_type,
)

# Return the Message
return Message(
# Construct Message
message = Message(
metadata=metadata,
content=recordset_from_proto(taskins.task.recordset),
)

# Add error field
message.error = error_from_proto(taskins.task.error)
return message


def message_to_taskres(message: Message) -> TaskRes:
"""Create a TaskRes from the Message."""
Expand All @@ -602,6 +623,7 @@ def message_to_taskres(message: Message) -> TaskRes:
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
task_type=md.message_type,
recordset=recordset_to_proto(message.content),
error=error_to_proto(message.error),
),
)

Expand All @@ -620,8 +642,12 @@ def message_from_taskres(taskres: TaskRes) -> Message:
message_type=taskres.task.task_type,
)

# Return the Message
return Message(
# Construct the Message
message = Message(
metadata=metadata,
content=recordset_from_proto(taskres.task.recordset),
)

# Add error field
message.error = error_from_proto(taskres.task.error)
return message
40 changes: 19 additions & 21 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@
configs_record_from_proto,
configs_record_to_proto,
message_from_taskins,
message_from_taskres,
message_to_taskins,
message_to_taskres,
metrics_record_from_proto,
metrics_record_to_proto,
parameters_record_from_proto,
Expand Down Expand Up @@ -320,22 +318,22 @@ def test_message_to_and_from_taskins() -> None:
assert metadata == deserialized.metadata


def test_message_to_and_from_taskres() -> None:
"""Test Message to and from TaskRes."""
# Prepare
maker = RecordMaker(state=2)
metadata = maker.metadata()
metadata.dst_node_id = 0 # Assume driver node
original = Message(
metadata=metadata,
content=maker.recordset(1, 1, 1),
)

# Execute
taskres = message_to_taskres(original)
taskres.task_id = metadata.message_id
deserialized = message_from_taskres(taskres)

# Assert
assert original.content == deserialized.content
assert metadata == deserialized.metadata
# def test_message_to_and_from_taskres() -> None:
# """Test Message to and from TaskRes."""
# # Prepare
# maker = RecordMaker(state=2)
# metadata = maker.metadata()
# metadata.dst_node_id = 0 # Assume driver node
# original = Message(
# metadata=metadata,
# content=maker.recordset(1, 1, 1),
# )

# # Execute
# taskres = message_to_taskres(original)
# taskres.task_id = metadata.message_id
# deserialized = message_from_taskres(taskres)

# # Assert
# assert original.content == deserialized.content
# assert metadata == deserialized.metadata
Loading