Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 23, 2024
1 parent 81db36f commit c5510ca
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
import unittest
from typing import List

from flwr.client.run_state import RunState
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.client.typing import FlowerCallable, Layer
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.flowercontext import FlowerContext, Metadata
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .utils import make_ffn

Expand All @@ -33,9 +31,11 @@ def make_mock_middleware(name: str, footprint: List[str]) -> Layer:

def middleware(context: FlowerContext, app: FlowerCallable) -> FlowerContext:
footprint.append(name)
# add empty ConfigRegcord to in_message for this middleware layer
context.in_message.set_configs(name=name, record=ConfigsRecord())
ctx: FlowerContext = app(context)
footprint.append(name)
# add empty ConfigRegcord to out_message for this middleware layer
ctx.out_message.set_configs(name=name, record=ConfigsRecord())
return ctx

Expand All @@ -54,6 +54,15 @@ def app(context: FlowerContext) -> FlowerContext:
return app


def _get_dummy_flower_context() -> FlowerContext:
return FlowerContext(
in_message=RecordSet(),
out_message=RecordSet(),
local=RecordSet(),
metadata=Metadata(run_id=0, task_id="", group_id="", ttl="", task_type="mock"),
)


class TestMakeApp(unittest.TestCase):
"""Tests for the `make_app` function."""

Expand All @@ -67,14 +76,7 @@ def test_multiple_middlewares(self) -> None:
make_mock_middleware(name, footprint) for name in mock_middleware_names
]

context = FlowerContext(
in_message=RecordSet(),
out_message=RecordSet(),
local=RecordSet(),
metadata=Metadata(
run_id=0, task_id="", group_id="", ttl="", task_type="mock"
),
)
context = _get_dummy_flower_context()

# Execute
wrapped_app = make_ffn(mock_app, mock_middleware_layers)
Expand All @@ -94,20 +96,21 @@ def test_filter(self) -> None:
# Prepare
footprint: List[str] = []
mock_app = make_mock_app("app", footprint)
task_ins = TaskIns()
context = _get_dummy_flower_context()

def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd:
def filter_layer(context: FlowerContext, _: FlowerCallable) -> FlowerContext:
footprint.append("filter")
fwd.task_ins.task_id += "filter"
context.in_message.set_configs(name="filter", record=ConfigsRecord())
context.out_message.set_configs(name="filter", record=ConfigsRecord())
# Skip calling app
return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({}))
return context

# Execute
wrapped_app = make_ffn(mock_app, [filter_layer])
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res
context_ = wrapped_app(context)

# Assert
self.assertEqual(footprint, ["filter"])
# pylint: disable-next=no-member
self.assertEqual(task_ins.task_id, "filter")
self.assertEqual(task_res.task_id, "filter")
self.assertEqual(list(context_.in_message.configs.keys())[0], "filter")
self.assertEqual(list(context_.out_message.configs.keys())[0], "filter")

0 comments on commit c5510ca

Please sign in to comment.