Skip to content

Commit

Permalink
Improve Message tests (#3040)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Mar 5, 2024
1 parent c05df2a commit a62d4fa
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 59 deletions.
2 changes: 2 additions & 0 deletions src/py/flwr/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .grpc import GRPC_MAX_MESSAGE_LENGTH
from .logger import configure as configure
from .logger import log as log
from .message import Error as Error
from .message import Message as Message
from .message import Metadata as Metadata
from .parameter import bytes_to_ndarray as bytes_to_ndarray
Expand Down Expand Up @@ -74,6 +75,7 @@
"EventType",
"FitIns",
"FitRes",
"Error",
"GetParametersIns",
"GetParametersRes",
"GetPropertiesIns",
Expand Down
109 changes: 109 additions & 0 deletions src/py/flwr/common/message_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Message tests."""


from contextlib import ExitStack
from typing import Any, Callable

import pytest

# pylint: enable=E0611
from . import RecordSet
from .message import Error, Message
from .serde_test import RecordMaker


@pytest.mark.parametrize(
"content_fn, error_fn, context",
[
(
lambda maker: maker.recordset(1, 1, 1),
None,
None,
), # check when only content is set
(None, lambda code: Error(code=code), None), # check when only error is set
(
lambda maker: maker.recordset(1, 1, 1),
lambda code: Error(code=code),
pytest.raises(ValueError),
), # check when both are set (ERROR)
(None, None, pytest.raises(ValueError)), # check when neither is set (ERROR)
],
)
def test_message_creation(
content_fn: Callable[
[
RecordMaker,
],
RecordSet,
],
error_fn: Callable[[int], Error],
context: Any,
) -> None:
"""Test Message creation attempting to pass content and/or error."""
# Prepare
maker = RecordMaker(state=2)
metadata = maker.metadata()

with ExitStack() as stack:
if context:
stack.enter_context(context)

_ = Message(
metadata=metadata,
content=None if content_fn is None else content_fn(maker),
error=None if error_fn is None else error_fn(0),
)


def create_message_with_content() -> Message:
"""Create a Message with content."""
maker = RecordMaker(state=2)
metadata = maker.metadata()
return Message(metadata=metadata, content=RecordSet())


def create_message_with_error() -> Message:
"""Create a Message with error."""
maker = RecordMaker(state=2)
metadata = maker.metadata()
return Message(metadata=metadata, error=Error(code=1))


@pytest.mark.parametrize(
"message_creation_fn",
[
create_message_with_content,
create_message_with_error,
],
)
def test_altering_message(
message_creation_fn: Callable[
[],
Message,
],
) -> None:
"""Test that a message with content doesn't allow setting an error.
And viceversa.
"""
message = message_creation_fn()

with pytest.raises(ValueError):
if message.has_content():
message.error = Error(code=123)
if message.has_error():
message.content = RecordSet()
93 changes: 34 additions & 59 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import random
import string
from contextlib import ExitStack
from typing import Any, Callable, Optional, OrderedDict, Type, TypeVar, Union, cast

import pytest
Expand Down Expand Up @@ -301,20 +300,13 @@ def test_recordset_serialization_deserialization() -> None:


@pytest.mark.parametrize(
"content_fn, error_fn, context",
"content_fn, error_fn",
[
(
lambda maker: maker.recordset(1, 1, 1),
None,
None,
), # check when only content is set
(None, lambda code: Error(code=code), None), # check when only error is set
(
lambda maker: maker.recordset(1, 1, 1),
lambda code: Error(code=code),
pytest.raises(ValueError),
), # check when both are set (ERROR)
(None, None, pytest.raises(ValueError)), # check when neither is set (ERROR)
(None, lambda code: Error(code=code)), # check when only error is set
],
)
def test_message_to_and_from_taskins(
Expand All @@ -325,7 +317,6 @@ def test_message_to_and_from_taskins(
RecordSet,
],
error_fn: Callable[[int], Error],
context: Any,
) -> None:
"""Test Message to and from TaskIns."""
# Prepare
Expand All @@ -335,44 +326,33 @@ def test_message_to_and_from_taskins(
# pylint: disable-next=protected-access
metadata._src_node_id = 0 # Assume driver node

with ExitStack() as stack:
if context:
stack.enter_context(context)

original = Message(
metadata=metadata,
content=None if content_fn is None else content_fn(maker),
error=None if error_fn is None else error_fn(0),
)
original = Message(
metadata=metadata,
content=None if content_fn is None else content_fn(maker),
error=None if error_fn is None else error_fn(0),
)

# Execute
taskins = message_to_taskins(original)
taskins.task_id = metadata.message_id
deserialized = message_from_taskins(taskins)
# Execute
taskins = message_to_taskins(original)
taskins.task_id = metadata.message_id
deserialized = message_from_taskins(taskins)

# Assert
if original.has_content():
assert original.content == deserialized.content
if original.has_error():
assert original.error == deserialized.error
assert metadata == deserialized.metadata
# Assert
if original.has_content():
assert original.content == deserialized.content
if original.has_error():
assert original.error == deserialized.error
assert metadata == deserialized.metadata


@pytest.mark.parametrize(
"content_fn, error_fn, context",
"content_fn, error_fn",
[
(
lambda maker: maker.recordset(1, 1, 1),
None,
None,
), # check when only content is set
(None, lambda code: Error(code=code), None), # check when only error is set
(
lambda maker: maker.recordset(1, 1, 1),
lambda code: Error(code=code),
pytest.raises(ValueError),
), # check when both are set (ERROR)
(None, None, pytest.raises(ValueError)), # check when neither is set (ERROR)
(None, lambda code: Error(code=code)), # check when only error is set
],
)
def test_message_to_and_from_taskres(
Expand All @@ -383,32 +363,27 @@ def test_message_to_and_from_taskres(
RecordSet,
],
error_fn: Callable[[int], Error],
context: Any,
) -> None:
"""Test Message to and from TaskRes."""
# Prepare
maker = RecordMaker(state=2)
metadata = maker.metadata()
metadata.dst_node_id = 0 # Assume driver node

with ExitStack() as stack:
if context:
stack.enter_context(context)

original = Message(
metadata=metadata,
content=None if content_fn is None else content_fn(maker),
error=None if error_fn is None else error_fn(0),
)
original = Message(
metadata=metadata,
content=None if content_fn is None else content_fn(maker),
error=None if error_fn is None else error_fn(0),
)

# Execute
taskres = message_to_taskres(original)
taskres.task_id = metadata.message_id
deserialized = message_from_taskres(taskres)
# Execute
taskres = message_to_taskres(original)
taskres.task_id = metadata.message_id
deserialized = message_from_taskres(taskres)

# Assert
if original.has_content():
assert original.content == deserialized.content
if original.has_error():
assert original.error == deserialized.error
assert metadata == deserialized.metadata
# Assert
if original.has_content():
assert original.content == deserialized.content
if original.has_error():
assert original.error == deserialized.error
assert metadata == deserialized.metadata

0 comments on commit a62d4fa

Please sign in to comment.