diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 6a588f3d02eb..88a1121eaf13 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -18,12 +18,10 @@ import unittest from typing import List -from flwr.client.run_state import RunState -from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.client.typing import FlowerCallable, Layer from flwr.common.configsrecord import ConfigsRecord from flwr.common.flowercontext import FlowerContext, Metadata from flwr.common.recordset import RecordSet -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .utils import make_ffn @@ -33,9 +31,11 @@ def make_mock_middleware(name: str, footprint: List[str]) -> Layer: def middleware(context: FlowerContext, app: FlowerCallable) -> FlowerContext: footprint.append(name) + # add empty ConfigRegcord to in_message for this middleware layer context.in_message.set_configs(name=name, record=ConfigsRecord()) ctx: FlowerContext = app(context) footprint.append(name) + # add empty ConfigRegcord to out_message for this middleware layer ctx.out_message.set_configs(name=name, record=ConfigsRecord()) return ctx @@ -54,6 +54,15 @@ def app(context: FlowerContext) -> FlowerContext: return app +def _get_dummy_flower_context() -> FlowerContext: + return FlowerContext( + in_message=RecordSet(), + out_message=RecordSet(), + local=RecordSet(), + metadata=Metadata(run_id=0, task_id="", group_id="", ttl="", task_type="mock"), + ) + + class TestMakeApp(unittest.TestCase): """Tests for the `make_app` function.""" @@ -67,14 +76,7 @@ def test_multiple_middlewares(self) -> None: make_mock_middleware(name, footprint) for name in mock_middleware_names ] - context = FlowerContext( - in_message=RecordSet(), - out_message=RecordSet(), - local=RecordSet(), - metadata=Metadata( - run_id=0, task_id="", group_id="", ttl="", task_type="mock" - ), - ) + context = _get_dummy_flower_context() # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) @@ -94,20 +96,21 @@ def test_filter(self) -> None: # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - task_ins = TaskIns() + context = _get_dummy_flower_context() - def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: + def filter_layer(context: FlowerContext, _: FlowerCallable) -> FlowerContext: footprint.append("filter") - fwd.task_ins.task_id += "filter" + context.in_message.set_configs(name="filter", record=ConfigsRecord()) + context.out_message.set_configs(name="filter", record=ConfigsRecord()) # Skip calling app - return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({})) + return context # Execute wrapped_app = make_ffn(mock_app, [filter_layer]) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res + context_ = wrapped_app(context) # Assert self.assertEqual(footprint, ["filter"]) # pylint: disable-next=no-member - self.assertEqual(task_ins.task_id, "filter") - self.assertEqual(task_res.task_id, "filter") + self.assertEqual(list(context_.in_message.configs.keys())[0], "filter") + self.assertEqual(list(context_.out_message.configs.keys())[0], "filter")