From 81db36ff4b897cb955a306f2084beb26eb609ec0 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 23 Jan 2024 19:54:34 +0000 Subject: [PATCH 01/13] v0 --- src/py/flwr/client/app.py | 40 ++++++++++----- src/py/flwr/client/flower.py | 23 +++++---- .../client/message_handler/message_handler.py | 49 +++++++++++++++++++ src/py/flwr/client/middleware/utils.py | 7 +-- src/py/flwr/client/middleware/utils_test.py | 38 +++++++++----- src/py/flwr/client/typing.py | 5 +- 6 files changed, 124 insertions(+), 38 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ae5beeae07d6..80ffa3914afb 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -24,7 +24,7 @@ from flwr.client.client import Client from flwr.client.flower import Flower -from flwr.client.typing import Bwd, ClientFn, Fwd +from flwr.client.typing import ClientFn from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address from flwr.common.constant import ( @@ -34,8 +34,11 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) +from flwr.common.flowercontext import FlowerContext, Metadata from flwr.common.logger import log, warn_experimental_feature -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.common.recordset import RecordSet +from flwr.common.serde import recordset_to_proto +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from .flower import load_flower_callable from .grpc_client.connection import grpc_connection @@ -323,6 +326,15 @@ def _load_app() -> Flower: connection, address = _init_connection(transport, server_address) node_state = NodeState() + # TODO: remove NodeState/RunState logic ? + + # TODO: initialize context here? + context = FlowerContext( + in_message=RecordSet(), + out_message=RecordSet(), + local=RecordSet(), + metadata=Metadata(run_id=-1, task_id="", group_id="", ttl="", task_type=""), + ) while True: sleep_duration: int = 0 @@ -354,24 +366,30 @@ def _load_app() -> Flower: # Register state node_state.register_runstate(run_id=task_ins.run_id) + # TODO: pulate context.metadata and context.in_message from TaskIns + # Load app app: Flower = load_flower_callable_fn() # Handle task message - fwd_msg: Fwd = Fwd( - task_ins=task_ins, - state=node_state.retrieve_runstate(run_id=task_ins.run_id), - ) - bwd_msg: Bwd = app(fwd=fwd_msg) + context_ = app(context=context) # Update node state - node_state.update_runstate( - run_id=bwd_msg.task_res.run_id, - run_state=bwd_msg.state, + # node_state.update_runstate( + # run_id=bwd_msg.task_res.run_id, + # run_state=bwd_msg.state, + # ) + + # TODO: Construct TaskRes from context.out_message + task_res = TaskRes( + task_id=context_.metadata.task_id, + group_id=context_.metadata.group_id, + run_id=context_.metadata.run_id, + task=Task(recordset=recordset_to_proto(context_.out_message)), ) # Send - send(bwd_msg.task_res) + send(task_res) # Unregister node if delete_node is not None: diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 535f096e5866..45f07cf58fde 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -18,9 +18,12 @@ import importlib from typing import List, Optional, cast -from flwr.client.message_handler.message_handler import handle +from flwr.client.message_handler.message_handler import ( + handle_legacy_message_from_tasktype, +) from flwr.client.middleware.utils import make_ffn -from flwr.client.typing import Bwd, ClientFn, Fwd, Layer +from flwr.client.typing import ClientFn, Layer +from flwr.common.flowercontext import FlowerContext class Flower: @@ -55,20 +58,20 @@ def __init__( layers: Optional[List[Layer]] = None, ) -> None: # Create wrapper function for `handle` - def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name - task_res, state_updated = handle( - client_fn=client_fn, - state=fwd.state, - task_ins=fwd.task_ins, + def ffn( + context: FlowerContext, + ) -> FlowerContext: # pylint: disable=invalid-name + context = handle_legacy_message_from_tasktype( + client_fn=client_fn, context=context ) - return Bwd(task_res=task_res, state=state_updated) + return context # Wrap middleware layers around the wrapped handle function self._call = make_ffn(ffn, layers if layers is not None else []) - def __call__(self, fwd: Fwd) -> Bwd: + def __call__(self, context: FlowerContext) -> FlowerContext: """.""" - return self._call(fwd) + return self._call(context) class LoadCallableError(Exception): diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 8cfe909c1738..a289a542e870 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -32,6 +32,17 @@ from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde +from flwr.common.flowercontext import FlowerContext +from flwr.common.recordset_compat import ( + evaluateres_to_recordset, + fitres_to_recordset, + getparametersres_to_recordset, + getpropertiesres_to_recordset, + recordset_to_evaluateins, + recordset_to_fitins, + recordset_to_getparametersins, + recordset_to_getpropertiesins, +) from flwr.proto.task_pb2 import ( # pylint: disable=E0611 SecureAggregation, Task, @@ -177,6 +188,44 @@ def handle_legacy_message( raise UnknownServerMessage() +def handle_legacy_message_from_tasktype( + client_fn: ClientFn, context: FlowerContext +) -> FlowerContext: + """Handle legacy message in the inner most middleware layer.""" + client = client_fn("-1") + task_type = context.metadata.task_type + + if task_type == "get_properties_ins": + get_properties_res = maybe_call_get_properties( + client=client, + get_properties_ins=recordset_to_getpropertiesins(context.in_message), + ) + context.out_message = getpropertiesres_to_recordset(get_properties_res) + elif task_type == "get_parameteres_ins": + get_parameters_res = maybe_call_get_parameters( + client=client, + get_parameters_ins=recordset_to_getparametersins(context.in_message), + ) + context.out_message = getparametersres_to_recordset(get_parameters_res) + elif task_type == "fit_ins": + fit_res = maybe_call_fit( + client=client, + fit_ins=recordset_to_fitins(context.in_message, keep_input=False), + ) + context.out_message = fitres_to_recordset(fit_res, keep_input=False) + elif task_type == "evaluate_ins": + evaluate_res = maybe_call_evaluate( + client=client, + evaluate_ins=recordset_to_evaluateins(context.in_message, keep_input=False), + ) + context.out_message = evaluateres_to_recordset(evaluate_res) + else: + # TODO: what to do with reconnect? + print("do something") + + return context + + def _reconnect( reconnect_msg: ServerMessage.ReconnectIns, ) -> Tuple[ClientMessage, int]: diff --git a/src/py/flwr/client/middleware/utils.py b/src/py/flwr/client/middleware/utils.py index d93132403c1e..4d42c3afa3bb 100644 --- a/src/py/flwr/client/middleware/utils.py +++ b/src/py/flwr/client/middleware/utils.py @@ -17,15 +17,16 @@ from typing import List -from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.client.typing import FlowerCallable, Layer +from flwr.common.flowercontext import FlowerContext def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable: """.""" def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable: - def new_ffn(fwd: Fwd) -> Bwd: - return _layer(fwd, _ffn) + def new_ffn(context: FlowerContext) -> FlowerContext: + return _layer(context, _ffn) return new_ffn diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 006fe6db4799..6a588f3d02eb 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -20,6 +20,9 @@ from flwr.client.run_state import RunState from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.flowercontext import FlowerContext, Metadata +from flwr.common.recordset import RecordSet from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .utils import make_ffn @@ -28,13 +31,13 @@ def make_mock_middleware(name: str, footprint: List[str]) -> Layer: """Make a mock middleware layer.""" - def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd: + def middleware(context: FlowerContext, app: FlowerCallable) -> FlowerContext: footprint.append(name) - fwd.task_ins.task_id += f"{name}" - bwd = app(fwd) + context.in_message.set_configs(name=name, record=ConfigsRecord()) + ctx: FlowerContext = app(context) footprint.append(name) - bwd.task_res.task_id += f"{name}" - return bwd + ctx.out_message.set_configs(name=name, record=ConfigsRecord()) + return ctx return middleware @@ -42,10 +45,11 @@ def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd: def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: """Make a mock app.""" - def app(fwd: Fwd) -> Bwd: + def app(context: FlowerContext) -> FlowerContext: footprint.append(name) - fwd.task_ins.task_id += f"{name}" - return Bwd(task_res=TaskRes(task_id=name), state=RunState({})) + context.in_message.set_configs(name=name, record=ConfigsRecord()) + context.out_message.set_configs(name=name, record=ConfigsRecord()) + return context return app @@ -62,18 +66,28 @@ def test_multiple_middlewares(self) -> None: mock_middleware_layers = [ make_mock_middleware(name, footprint) for name in mock_middleware_names ] - task_ins = TaskIns() + + context = FlowerContext( + in_message=RecordSet(), + out_message=RecordSet(), + local=RecordSet(), + metadata=Metadata( + run_id=0, task_id="", group_id="", ttl="", task_type="mock" + ), + ) # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res + context_ = wrapped_app(context) # Assert trace = mock_middleware_names + ["app"] self.assertEqual(footprint, trace + list(reversed(mock_middleware_names))) # pylint: disable-next=no-member - self.assertEqual(task_ins.task_id, "".join(trace)) - self.assertEqual(task_res.task_id, "".join(reversed(trace))) + self.assertEqual("".join(context_.in_message.configs.keys()), "".join(trace)) + self.assertEqual( + "".join(context_.out_message.configs.keys()), "".join(reversed(trace)) + ) def test_filter(self) -> None: """Test if a middleware can filter incoming TaskIns.""" diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 5291afb83d98..81b2eda26311 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -18,6 +18,7 @@ from typing import Callable from flwr.client.run_state import RunState +from flwr.common.flowercontext import FlowerContext from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .client import Client as Client @@ -39,6 +40,6 @@ class Bwd: state: RunState -FlowerCallable = Callable[[Fwd], Bwd] +FlowerCallable = Callable[[FlowerContext], FlowerContext] ClientFn = Callable[[str], Client] -Layer = Callable[[Fwd, FlowerCallable], Bwd] +Layer = Callable[[FlowerContext, FlowerCallable], FlowerContext] From c5510ca4af6ddd287f60d852ce58c06affb63b31 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 23 Jan 2024 20:19:58 +0000 Subject: [PATCH 02/13] more --- src/py/flwr/client/middleware/utils_test.py | 39 +++++++++++---------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 6a588f3d02eb..88a1121eaf13 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -18,12 +18,10 @@ import unittest from typing import List -from flwr.client.run_state import RunState -from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.client.typing import FlowerCallable, Layer from flwr.common.configsrecord import ConfigsRecord from flwr.common.flowercontext import FlowerContext, Metadata from flwr.common.recordset import RecordSet -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .utils import make_ffn @@ -33,9 +31,11 @@ def make_mock_middleware(name: str, footprint: List[str]) -> Layer: def middleware(context: FlowerContext, app: FlowerCallable) -> FlowerContext: footprint.append(name) + # add empty ConfigRegcord to in_message for this middleware layer context.in_message.set_configs(name=name, record=ConfigsRecord()) ctx: FlowerContext = app(context) footprint.append(name) + # add empty ConfigRegcord to out_message for this middleware layer ctx.out_message.set_configs(name=name, record=ConfigsRecord()) return ctx @@ -54,6 +54,15 @@ def app(context: FlowerContext) -> FlowerContext: return app +def _get_dummy_flower_context() -> FlowerContext: + return FlowerContext( + in_message=RecordSet(), + out_message=RecordSet(), + local=RecordSet(), + metadata=Metadata(run_id=0, task_id="", group_id="", ttl="", task_type="mock"), + ) + + class TestMakeApp(unittest.TestCase): """Tests for the `make_app` function.""" @@ -67,14 +76,7 @@ def test_multiple_middlewares(self) -> None: make_mock_middleware(name, footprint) for name in mock_middleware_names ] - context = FlowerContext( - in_message=RecordSet(), - out_message=RecordSet(), - local=RecordSet(), - metadata=Metadata( - run_id=0, task_id="", group_id="", ttl="", task_type="mock" - ), - ) + context = _get_dummy_flower_context() # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) @@ -94,20 +96,21 @@ def test_filter(self) -> None: # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - task_ins = TaskIns() + context = _get_dummy_flower_context() - def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: + def filter_layer(context: FlowerContext, _: FlowerCallable) -> FlowerContext: footprint.append("filter") - fwd.task_ins.task_id += "filter" + context.in_message.set_configs(name="filter", record=ConfigsRecord()) + context.out_message.set_configs(name="filter", record=ConfigsRecord()) # Skip calling app - return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({})) + return context # Execute wrapped_app = make_ffn(mock_app, [filter_layer]) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res + context_ = wrapped_app(context) # Assert self.assertEqual(footprint, ["filter"]) # pylint: disable-next=no-member - self.assertEqual(task_ins.task_id, "filter") - self.assertEqual(task_res.task_id, "filter") + self.assertEqual(list(context_.in_message.configs.keys())[0], "filter") + self.assertEqual(list(context_.out_message.configs.keys())[0], "filter") From 3ad52390dfefd24e99a2d3beab3ba6e0673ee17e Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 25 Jan 2024 10:51:01 +0000 Subject: [PATCH 03/13] update with new `Message` and `Context` logic --- src/py/flwr/client/app.py | 35 +++++----- src/py/flwr/client/flower.py | 18 ++--- .../client/message_handler/message_handler.py | 32 +++++---- src/py/flwr/client/middleware/utils.py | 7 +- src/py/flwr/client/middleware/utils_test.py | 65 +++++++++++-------- src/py/flwr/client/typing.py | 7 +- 6 files changed, 94 insertions(+), 70 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 80ffa3914afb..708bcb35ef16 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -34,8 +34,9 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) -from flwr.common.flowercontext import FlowerContext, Metadata +from flwr.common.context import Context from flwr.common.logger import log, warn_experimental_feature +from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet from flwr.common.serde import recordset_to_proto from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -326,15 +327,7 @@ def _load_app() -> Flower: connection, address = _init_connection(transport, server_address) node_state = NodeState() - # TODO: remove NodeState/RunState logic ? - - # TODO: initialize context here? - context = FlowerContext( - in_message=RecordSet(), - out_message=RecordSet(), - local=RecordSet(), - metadata=Metadata(run_id=-1, task_id="", group_id="", ttl="", task_type=""), - ) + # TODO: make NodeState work with RecordSet while True: sleep_duration: int = 0 @@ -366,13 +359,23 @@ def _load_app() -> Flower: # Register state node_state.register_runstate(run_id=task_ins.run_id) - # TODO: pulate context.metadata and context.in_message from TaskIns + # TODO: get runstate from nodestate and construct context for this run + context = Context(state=RecordSet()) + + # TODO: get Message from TaskIns + + message = Message( + metadata=Metadata( + run_id=0, task_id="", group_id="", ttl="", task_type="mock" + ), + message=RecordSet(), + ) # Load app app: Flower = load_flower_callable_fn() # Handle task message - context_ = app(context=context) + out_message = app(message=message, context=context) # Update node state # node_state.update_runstate( @@ -382,10 +385,10 @@ def _load_app() -> Flower: # TODO: Construct TaskRes from context.out_message task_res = TaskRes( - task_id=context_.metadata.task_id, - group_id=context_.metadata.group_id, - run_id=context_.metadata.run_id, - task=Task(recordset=recordset_to_proto(context_.out_message)), + task_id=message.metadata.task_id, + group_id=message.metadata.group_id, + run_id=message.metadata.run_id, + task=Task(recordset=recordset_to_proto(out_message.message)), ) # Send diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 45f07cf58fde..91f64502acfa 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -23,7 +23,8 @@ ) from flwr.client.middleware.utils import make_ffn from flwr.client.typing import ClientFn, Layer -from flwr.common.flowercontext import FlowerContext +from flwr.common.context import Context +from flwr.common.message import Message class Flower: @@ -59,19 +60,20 @@ def __init__( ) -> None: # Create wrapper function for `handle` def ffn( - context: FlowerContext, - ) -> FlowerContext: # pylint: disable=invalid-name - context = handle_legacy_message_from_tasktype( - client_fn=client_fn, context=context + message: Message, + context: Context, + ) -> Message: # pylint: disable=invalid-name + out_message = handle_legacy_message_from_tasktype( + client_fn=client_fn, message=message, context=context ) - return context + return out_message # Wrap middleware layers around the wrapped handle function self._call = make_ffn(ffn, layers if layers is not None else []) - def __call__(self, context: FlowerContext) -> FlowerContext: + def __call__(self, message: Message, context: Context) -> Message: """.""" - return self._call(context) + return self._call(message, context) class LoadCallableError(Exception): diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index a289a542e870..47d89b7f2c36 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -32,7 +32,9 @@ from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde -from flwr.common.flowercontext import FlowerContext +from flwr.common.context import Context +from flwr.common.message import Message +from flwr.common.recordset import RecordSet from flwr.common.recordset_compat import ( evaluateres_to_recordset, fitres_to_recordset, @@ -189,41 +191,45 @@ def handle_legacy_message( def handle_legacy_message_from_tasktype( - client_fn: ClientFn, context: FlowerContext -) -> FlowerContext: + client_fn: ClientFn, message: Message, context: Context +) -> Message: """Handle legacy message in the inner most middleware layer.""" client = client_fn("-1") - task_type = context.metadata.task_type + # TODO: inject state (i.e. context.state) into client? + + task_type = message.metadata.task_type + + out_message = Message(metadata=message.metadata, message=RecordSet()) if task_type == "get_properties_ins": get_properties_res = maybe_call_get_properties( client=client, - get_properties_ins=recordset_to_getpropertiesins(context.in_message), + get_properties_ins=recordset_to_getpropertiesins(message.message), ) - context.out_message = getpropertiesres_to_recordset(get_properties_res) + out_message.message = getpropertiesres_to_recordset(get_properties_res) elif task_type == "get_parameteres_ins": get_parameters_res = maybe_call_get_parameters( client=client, - get_parameters_ins=recordset_to_getparametersins(context.in_message), + get_parameters_ins=recordset_to_getparametersins(message.message), ) - context.out_message = getparametersres_to_recordset(get_parameters_res) + out_message.message = getparametersres_to_recordset(get_parameters_res) elif task_type == "fit_ins": fit_res = maybe_call_fit( client=client, - fit_ins=recordset_to_fitins(context.in_message, keep_input=False), + fit_ins=recordset_to_fitins(message.message, keep_input=False), ) - context.out_message = fitres_to_recordset(fit_res, keep_input=False) + out_message.message = fitres_to_recordset(fit_res, keep_input=False) elif task_type == "evaluate_ins": evaluate_res = maybe_call_evaluate( client=client, - evaluate_ins=recordset_to_evaluateins(context.in_message, keep_input=False), + evaluate_ins=recordset_to_evaluateins(message.message, keep_input=False), ) - context.out_message = evaluateres_to_recordset(evaluate_res) + out_message.message = evaluateres_to_recordset(evaluate_res) else: # TODO: what to do with reconnect? print("do something") - return context + return out_message def _reconnect( diff --git a/src/py/flwr/client/middleware/utils.py b/src/py/flwr/client/middleware/utils.py index 4d42c3afa3bb..8ed445f5599c 100644 --- a/src/py/flwr/client/middleware/utils.py +++ b/src/py/flwr/client/middleware/utils.py @@ -18,15 +18,16 @@ from typing import List from flwr.client.typing import FlowerCallable, Layer -from flwr.common.flowercontext import FlowerContext +from flwr.common.context import Context +from flwr.common.message import Message def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable: """.""" def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable: - def new_ffn(context: FlowerContext) -> FlowerContext: - return _layer(context, _ffn) + def new_ffn(message: Message, context: Context) -> Message: + return _layer(message, _ffn, context) return new_ffn diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 88a1121eaf13..e30934c86f86 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -20,7 +20,8 @@ from flwr.client.typing import FlowerCallable, Layer from flwr.common.configsrecord import ConfigsRecord -from flwr.common.flowercontext import FlowerContext, Metadata +from flwr.common.context import Context +from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet from .utils import make_ffn @@ -29,15 +30,17 @@ def make_mock_middleware(name: str, footprint: List[str]) -> Layer: """Make a mock middleware layer.""" - def middleware(context: FlowerContext, app: FlowerCallable) -> FlowerContext: + def middleware(message: Message, app: FlowerCallable, context: Context) -> Message: footprint.append(name) # add empty ConfigRegcord to in_message for this middleware layer - context.in_message.set_configs(name=name, record=ConfigsRecord()) - ctx: FlowerContext = app(context) + message.message.set_configs(name=name, record=ConfigsRecord()) + context.state.metrics['context']['counter'] += 1 + out_message: Message = app(message, context) footprint.append(name) + context.state.metrics['context']['counter'] += 1 # add empty ConfigRegcord to out_message for this middleware layer - ctx.out_message.set_configs(name=name, record=ConfigsRecord()) - return ctx + out_message.message.set_configs(name=name, record=ConfigsRecord()) + return out_message return middleware @@ -45,20 +48,20 @@ def middleware(context: FlowerContext, app: FlowerCallable) -> FlowerContext: def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: """Make a mock app.""" - def app(context: FlowerContext) -> FlowerContext: + def app(message: Message, context: Context) -> Message: footprint.append(name) - context.in_message.set_configs(name=name, record=ConfigsRecord()) - context.out_message.set_configs(name=name, record=ConfigsRecord()) - return context + message.message.set_configs(name=name, record=ConfigsRecord()) + out_message = Message(metadata=message.metadata, message=RecordSet()) + out_message.message.set_configs(name=name, record=ConfigsRecord()) + print(context) + return out_message return app -def _get_dummy_flower_context() -> FlowerContext: - return FlowerContext( - in_message=RecordSet(), - out_message=RecordSet(), - local=RecordSet(), +def _get_dummy_flower_message() -> Message: + return Message( + message=RecordSet(), metadata=Metadata(run_id=0, task_id="", group_id="", ttl="", task_type="mock"), ) @@ -76,41 +79,49 @@ def test_multiple_middlewares(self) -> None: make_mock_middleware(name, footprint) for name in mock_middleware_names ] - context = _get_dummy_flower_context() + state = RecordSet() + state.set_metrics('context', {'counter': 0.0}) + context = Context(state=state) + message = _get_dummy_flower_message() # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) - context_ = wrapped_app(context) + out_message = wrapped_app(message, context) # Assert trace = mock_middleware_names + ["app"] self.assertEqual(footprint, trace + list(reversed(mock_middleware_names))) # pylint: disable-next=no-member - self.assertEqual("".join(context_.in_message.configs.keys()), "".join(trace)) + self.assertEqual("".join(message.message.configs.keys()), "".join(trace)) self.assertEqual( - "".join(context_.out_message.configs.keys()), "".join(reversed(trace)) + "".join(out_message.message.configs.keys()), "".join(reversed(trace)) ) + self.assertEqual(state.get_metrics('context')['counter'], 2*len(mock_middleware_layers)) def test_filter(self) -> None: """Test if a middleware can filter incoming TaskIns.""" # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - context = _get_dummy_flower_context() + context = Context(state=RecordSet()) + message = _get_dummy_flower_message() - def filter_layer(context: FlowerContext, _: FlowerCallable) -> FlowerContext: + def filter_layer( + message: Message, _: FlowerCallable, context: Context + ) -> Message: footprint.append("filter") - context.in_message.set_configs(name="filter", record=ConfigsRecord()) - context.out_message.set_configs(name="filter", record=ConfigsRecord()) + message.message.set_configs(name="filter", record=ConfigsRecord()) + out_message = Message(metadata=message.metadata, message=RecordSet()) + out_message.message.set_configs(name="filter", record=ConfigsRecord()) # Skip calling app - return context + return out_message # Execute wrapped_app = make_ffn(mock_app, [filter_layer]) - context_ = wrapped_app(context) + out_message = wrapped_app(message, context) # Assert self.assertEqual(footprint, ["filter"]) # pylint: disable-next=no-member - self.assertEqual(list(context_.in_message.configs.keys())[0], "filter") - self.assertEqual(list(context_.out_message.configs.keys())[0], "filter") + self.assertEqual(list(message.message.configs.keys())[0], "filter") + self.assertEqual(list(out_message.message.configs.keys())[0], "filter") diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 81b2eda26311..f8ca7edfffa7 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -18,7 +18,8 @@ from typing import Callable from flwr.client.run_state import RunState -from flwr.common.flowercontext import FlowerContext +from flwr.common.context import Context +from flwr.common.message import Message from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .client import Client as Client @@ -40,6 +41,6 @@ class Bwd: state: RunState -FlowerCallable = Callable[[FlowerContext], FlowerContext] +FlowerCallable = Callable[[Message, Context], Message] ClientFn = Callable[[str], Client] -Layer = Callable[[FlowerContext, FlowerCallable], FlowerContext] +Layer = Callable[[Message, FlowerCallable, Context], Message] From d2f326969243e32557a5b570f2f61d1305eb1811 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 25 Jan 2024 18:00:31 +0000 Subject: [PATCH 04/13] wip --- src/py/flwr/client/middleware/utils.py | 2 +- src/py/flwr/client/middleware/utils_test.py | 29 ++++++++++++++++----- src/py/flwr/client/typing.py | 2 +- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/client/middleware/utils.py b/src/py/flwr/client/middleware/utils.py index 8ed445f5599c..6d24da558576 100644 --- a/src/py/flwr/client/middleware/utils.py +++ b/src/py/flwr/client/middleware/utils.py @@ -27,7 +27,7 @@ def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable: def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable: def new_ffn(message: Message, context: Context) -> Message: - return _layer(message, _ffn, context) + return _layer(message, context, _ffn) return new_ffn diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index e30934c86f86..301647755be7 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -22,22 +22,34 @@ from flwr.common.configsrecord import ConfigsRecord from flwr.common.context import Context from flwr.common.message import Message, Metadata +from flwr.common.metricsrecord import MetricsRecord from flwr.common.recordset import RecordSet from .utils import make_ffn +METRIC = "context" +COUNTER = "counter" + + +def _increment_context_counter(context: Context) -> None: + # Read from context + current_counter: int = context.state.get_metrics(METRIC)[COUNTER] # type: ignore + # update and override context + current_counter += 1 + context.state.set_metrics(METRIC, record=MetricsRecord({COUNTER: current_counter})) + def make_mock_middleware(name: str, footprint: List[str]) -> Layer: """Make a mock middleware layer.""" - def middleware(message: Message, app: FlowerCallable, context: Context) -> Message: + def middleware(message: Message, context: Context, app: FlowerCallable) -> Message: footprint.append(name) # add empty ConfigRegcord to in_message for this middleware layer message.message.set_configs(name=name, record=ConfigsRecord()) - context.state.metrics['context']['counter'] += 1 + _increment_context_counter(context) out_message: Message = app(message, context) footprint.append(name) - context.state.metrics['context']['counter'] += 1 + _increment_context_counter(context) # add empty ConfigRegcord to out_message for this middleware layer out_message.message.set_configs(name=name, record=ConfigsRecord()) return out_message @@ -80,7 +92,7 @@ def test_multiple_middlewares(self) -> None: ] state = RecordSet() - state.set_metrics('context', {'counter': 0.0}) + state.set_metrics(METRIC, record=MetricsRecord({COUNTER: 0.0})) context = Context(state=state) message = _get_dummy_flower_message() @@ -96,7 +108,9 @@ def test_multiple_middlewares(self) -> None: self.assertEqual( "".join(out_message.message.configs.keys()), "".join(reversed(trace)) ) - self.assertEqual(state.get_metrics('context')['counter'], 2*len(mock_middleware_layers)) + self.assertEqual( + state.get_metrics(METRIC)[COUNTER], 2 * len(mock_middleware_layers) + ) def test_filter(self) -> None: """Test if a middleware can filter incoming TaskIns.""" @@ -107,10 +121,13 @@ def test_filter(self) -> None: message = _get_dummy_flower_message() def filter_layer( - message: Message, _: FlowerCallable, context: Context + message: Message, + context: Context, + _: FlowerCallable, ) -> Message: footprint.append("filter") message.message.set_configs(name="filter", record=ConfigsRecord()) + context = context # we need to do something with it else mypy issue out_message = Message(metadata=message.metadata, message=RecordSet()) out_message.message.set_configs(name="filter", record=ConfigsRecord()) # Skip calling app diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index f8ca7edfffa7..ec15981dbab3 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -43,4 +43,4 @@ class Bwd: FlowerCallable = Callable[[Message, Context], Message] ClientFn = Callable[[str], Client] -Layer = Callable[[Message, FlowerCallable, Context], Message] +Layer = Callable[[Message, Context, FlowerCallable], Message] From 1de02b180ca41cc36e2adb7efb6bb95baa98666d Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 25 Jan 2024 19:29:42 +0000 Subject: [PATCH 05/13] Add TaskIns to message and message to TaskRes --- src/py/flwr/client/app.py | 24 ++++++--------------- src/py/flwr/client/middleware/utils_test.py | 5 ++--- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 708bcb35ef16..72f9e47843d0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -36,10 +36,9 @@ ) from flwr.common.context import Context from flwr.common.logger import log, warn_experimental_feature -from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet -from flwr.common.serde import recordset_to_proto -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .flower import load_flower_callable from .grpc_client.connection import grpc_connection @@ -362,14 +361,8 @@ def _load_app() -> Flower: # TODO: get runstate from nodestate and construct context for this run context = Context(state=RecordSet()) - # TODO: get Message from TaskIns - - message = Message( - metadata=Metadata( - run_id=0, task_id="", group_id="", ttl="", task_type="mock" - ), - message=RecordSet(), - ) + # Get Message from TaskIns + message = message_from_taskins(task_ins) # Load app app: Flower = load_flower_callable_fn() @@ -383,13 +376,8 @@ def _load_app() -> Flower: # run_state=bwd_msg.state, # ) - # TODO: Construct TaskRes from context.out_message - task_res = TaskRes( - task_id=message.metadata.task_id, - group_id=message.metadata.group_id, - run_id=message.metadata.run_id, - task=Task(recordset=recordset_to_proto(out_message.message)), - ) + # Construct TaskRes from out_message + task_res = message_to_taskres(out_message) # Send send(task_res) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 301647755be7..73abdbc61302 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -122,12 +122,11 @@ def test_filter(self) -> None: def filter_layer( message: Message, - context: Context, - _: FlowerCallable, + _1: Context, + _2: FlowerCallable, ) -> Message: footprint.append("filter") message.message.set_configs(name="filter", record=ConfigsRecord()) - context = context # we need to do something with it else mypy issue out_message = Message(metadata=message.metadata, message=RecordSet()) out_message.message.set_configs(name="filter", record=ConfigsRecord()) # Skip calling app From 69a7a17fac41e18f9dacf2f41395235f3e7abf9f Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 26 Jan 2024 14:50:25 +0000 Subject: [PATCH 06/13] update send and receive for grpc-bidi client --- src/py/flwr/client/grpc_client/connection.py | 81 ++++++++++++++++++- .../client/message_handler/message_handler.py | 26 ++++-- src/py/flwr/common/constant.py | 5 ++ 3 files changed, 100 insertions(+), 12 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 5f11912c587c..5a4a3f8fbf8b 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -23,6 +23,14 @@ from typing import Callable, Iterator, Optional, Tuple, Union from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common import recordset_compat as compat +from flwr.common import serde, typing +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.proto.node_pb2 import Node # pylint: disable=E0611 @@ -118,7 +126,42 @@ def grpc_connection( server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None)) def receive() -> TaskIns: - server_message = next(server_message_iterator) + # Receive ServerMessage proto + proto = next(server_message_iterator) + + # ServerMessage proto --> *Ins --> RecordSet + field = proto.WhichOneof("msg") + task_type = "" + if field == "get_properties_ins": + recordset = compat.getpropertiesins_to_recordset( + serde.get_properties_ins_from_proto(proto.get_properties_ins) + ) + task_type = TASK_TYPE_GET_PROPERTIES + elif field == "get_parameters_ins": + recordset = compat.getparametersins_to_recordset( + serde.get_parameters_ins_from_proto(proto.get_parameters_ins) + ) + task_type = TASK_TYPE_GET_PARAMETERS + elif field == "fit_ins": + recordset = compat.fitins_to_recordset( + serde.fit_ins_from_proto(proto.fit_ins), False + ) + task_type = TASK_TYPE_FIT + elif field == "evaluate_ins": + recordset = compat.evaluateins_to_recordset( + serde.evaluate_ins_from_proto(proto.evaluate_ins), False + ) + task_type = TASK_TYPE_EVALUATE + else: + raise ValueError( + "Unsupported instruction in ServerMessage, " + "cannot deserialize from ProtoBuf" + ) + + # RecordSet --> RecordSet proto + recordset_proto = serde.recordset_to_proto(recordset) + + # Construct TaskIns return TaskIns( task_id=str(uuid.uuid4()), group_id="", @@ -127,13 +170,43 @@ def receive() -> TaskIns: producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), ancestry=[], - legacy_server_message=server_message, + task_type=task_type, + recordset=recordset_proto, ), ) def send(task_res: TaskRes) -> None: - msg = task_res.task.legacy_client_message - return queue.put(msg, block=False) + # Retrieve RecordSet and task_type + recordset = serde.recordset_from_proto(task_res.task.recordset) + task_type = task_res.task.task_type + + # RecordSet --> *Res --> ClientMessage + if task_type == TASK_TYPE_GET_PROPERTIES: + client_message = typing.ClientMessage( + get_properties_res=compat.recordset_to_getpropertiesres(recordset) + ) + elif task_type == TASK_TYPE_GET_PARAMETERS: + client_message = typing.ClientMessage( + get_parameters_res=compat.recordset_to_getparametersres( + recordset, False + ) + ) + elif task_type == TASK_TYPE_FIT: + client_message = typing.ClientMessage( + fit_res=compat.recordset_to_fitres(recordset, False) + ) + elif task_type == TASK_TYPE_EVALUATE: + client_message = typing.ClientMessage( + evaluate_res=compat.recordset_to_evaluateres(recordset) + ) + else: + raise ValueError(f"Invalid task type: {task_type}") + + # ClientMessage --> ClientMessage proto + client_message_proto = serde.client_message_to_proto(client_message) + + # Send ClientMessage proto + return queue.put(client_message_proto, block=False) try: # Yield methods diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 47d89b7f2c36..8397d675b640 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -32,6 +32,12 @@ from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) from flwr.common.context import Context from flwr.common.message import Message from flwr.common.recordset import RecordSet @@ -201,34 +207,38 @@ def handle_legacy_message_from_tasktype( task_type = message.metadata.task_type out_message = Message(metadata=message.metadata, message=RecordSet()) - if task_type == "get_properties_ins": + # Handle GetPropertiesIns + if task_type == TASK_TYPE_GET_PROPERTIES: get_properties_res = maybe_call_get_properties( client=client, get_properties_ins=recordset_to_getpropertiesins(message.message), ) out_message.message = getpropertiesres_to_recordset(get_properties_res) - elif task_type == "get_parameteres_ins": + # Handle GetParametersIns + elif task_type == TASK_TYPE_GET_PARAMETERS: get_parameters_res = maybe_call_get_parameters( client=client, get_parameters_ins=recordset_to_getparametersins(message.message), ) - out_message.message = getparametersres_to_recordset(get_parameters_res) - elif task_type == "fit_ins": + out_message.message = getparametersres_to_recordset( + get_parameters_res, keep_input=False + ) + # Handle FitIns + elif task_type == TASK_TYPE_FIT: fit_res = maybe_call_fit( client=client, fit_ins=recordset_to_fitins(message.message, keep_input=False), ) out_message.message = fitres_to_recordset(fit_res, keep_input=False) - elif task_type == "evaluate_ins": + # Handle EvaluateIns + elif task_type == TASK_TYPE_EVALUATE: evaluate_res = maybe_call_evaluate( client=client, evaluate_ins=recordset_to_evaluateins(message.message, keep_input=False), ) out_message.message = evaluateres_to_recordset(evaluate_res) else: - # TODO: what to do with reconnect? - print("do something") - + raise ValueError(f"Invalid task type: {task_type}") return out_message diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 49802f2815be..8d1d865f084b 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -31,3 +31,8 @@ TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, ] + +TASK_TYPE_GET_PROPERTIES = "get_properties" +TASK_TYPE_GET_PARAMETERS = "get_parameters" +TASK_TYPE_FIT = "fit" +TASK_TYPE_EVALUATE = "evaluate" From d25ebb2d734195b77b8a209d0111131fad7ba4b0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 26 Jan 2024 17:58:11 +0000 Subject: [PATCH 07/13] update handle --- src/py/flwr/client/grpc_client/connection.py | 47 ++-- .../client/grpc_client/connection_test.py | 45 ++-- .../client/message_handler/message_handler.py | 207 +++--------------- .../message_handler/message_handler_test.py | 146 ++++-------- 4 files changed, 113 insertions(+), 332 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 5a4a3f8fbf8b..4ca80642199d 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -20,11 +20,12 @@ from logging import DEBUG from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Tuple, Union +from typing import Callable, Iterator, Optional, Tuple, Union, cast from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common import recordset_compat as compat -from flwr.common import serde, typing +from flwr.common import serde +from flwr.common.configsrecord import ConfigsRecord from flwr.common.constant import ( TASK_TYPE_EVALUATE, TASK_TYPE_FIT, @@ -33,10 +34,12 @@ ) from flwr.common.grpc import create_channel from flwr.common.logger import log +from flwr.common.recordset import RecordSet from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, + Reason, ServerMessage, ) from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611 @@ -54,7 +57,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager -def grpc_connection( +def grpc_connection( # pylint: disable=R0915 server_address: str, insecure: bool, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, @@ -152,6 +155,12 @@ def receive() -> TaskIns: serde.evaluate_ins_from_proto(proto.evaluate_ins), False ) task_type = TASK_TYPE_EVALUATE + elif field == "reconnect_ins": + recordset = RecordSet() + recordset.set_configs( + "config", ConfigsRecord({"seconds": proto.reconnect_ins.seconds}) + ) + task_type = "reconnect" else: raise ValueError( "Unsupported instruction in ServerMessage, " @@ -180,33 +189,33 @@ def send(task_res: TaskRes) -> None: recordset = serde.recordset_from_proto(task_res.task.recordset) task_type = task_res.task.task_type - # RecordSet --> *Res --> ClientMessage + # RecordSet --> *Res --> *Res proto -> ClientMessage proto if task_type == TASK_TYPE_GET_PROPERTIES: - client_message = typing.ClientMessage( - get_properties_res=compat.recordset_to_getpropertiesres(recordset) + getpropres = compat.recordset_to_getpropertiesres(recordset) + msg_proto = ClientMessage( + get_properties_res=serde.get_properties_res_to_proto(getpropres) ) elif task_type == TASK_TYPE_GET_PARAMETERS: - client_message = typing.ClientMessage( - get_parameters_res=compat.recordset_to_getparametersres( - recordset, False - ) + getparamres = compat.recordset_to_getparametersres(recordset, False) + msg_proto = ClientMessage( + get_parameters_res=serde.get_parameters_res_to_proto(getparamres) ) elif task_type == TASK_TYPE_FIT: - client_message = typing.ClientMessage( - fit_res=compat.recordset_to_fitres(recordset, False) - ) + fitres = compat.recordset_to_fitres(recordset, False) + msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres)) elif task_type == TASK_TYPE_EVALUATE: - client_message = typing.ClientMessage( - evaluate_res=compat.recordset_to_evaluateres(recordset) + evalres = compat.recordset_to_evaluateres(recordset) + msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres)) + elif task_type == "reconnect": + reason = cast(Reason.ValueType, recordset.get_configs("config")["reason"]) + msg_proto = ClientMessage( + disconnect_res=ClientMessage.DisconnectRes(reason=reason) ) else: raise ValueError(f"Invalid task type: {task_type}") - # ClientMessage --> ClientMessage proto - client_message_proto = serde.client_message_to_proto(client_message) - # Send ClientMessage proto - return queue.put(client_message_proto, block=False) + return queue.put(msg_proto, block=False) try: # Yield methods diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index bcfa76bb36c0..f2b362750df4 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,6 +23,12 @@ import grpc +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import TASK_TYPE_GET_PROPERTIES +from flwr.common.recordset import RecordSet +from flwr.common.typing import Code, GetPropertiesRes, Status from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, @@ -35,11 +41,21 @@ EXPECTED_NUM_SERVER_MESSAGE = 10 -SERVER_MESSAGE = ServerMessage() +SERVER_MESSAGE = ServerMessage(get_properties_ins=ServerMessage.GetPropertiesIns()) SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect_ins=ServerMessage.ReconnectIns()) -CLIENT_MESSAGE = ClientMessage() -CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect_res=ClientMessage.DisconnectRes()) +TASK_GET_PROPERTIES = Task( + task_type=TASK_TYPE_GET_PROPERTIES, + recordset=serde.recordset_to_proto( + compat.getpropertiesres_to_recordset(GetPropertiesRes(Status(Code.OK, ""), {})) + ), +) +TASK_DISCONNECT = Task( + task_type="reconnect", + recordset=serde.recordset_to_proto( + RecordSet(configs={"config": ConfigsRecord({"reason": 0})}) + ), +) def unused_tcp_port() -> int: @@ -104,31 +120,14 @@ def run_client() -> int: # Block until server responds with a message task_ins = receive() - if task_ins is None: - raise ValueError("Unexpected None value") - - # pylint: disable=no-member - if task_ins.HasField("task") and task_ins.task.HasField( - "legacy_server_message" - ): - server_message = task_ins.task.legacy_server_message - else: - server_message = None - # pylint: enable=no-member - - if server_message is None: - raise ValueError("Unexpected None value") - messages_received += 1 - if server_message.HasField("reconnect_ins"): - task_res = TaskRes( - task=Task(legacy_client_message=CLIENT_MESSAGE_DISCONNECT) - ) + if task_ins.task.task_type == "reconnect": # type: ignore + task_res = TaskRes(task=TASK_DISCONNECT) send(task_res) break # Process server_message and send client_message... - task_res = TaskRes(task=Task(legacy_client_message=CLIENT_MESSAGE)) + task_res = TaskRes(task=TASK_GET_PROPERTIES) send(task_res) return messages_received diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 000ecdbdb35e..e16d190b0289 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -15,22 +15,17 @@ """Client-side message handler.""" -from typing import Optional, Tuple +from typing import Optional, Tuple, cast from flwr.client.client import ( - Client, maybe_call_evaluate, maybe_call_fit, maybe_call_get_parameters, maybe_call_get_properties, ) -from flwr.client.message_handler.task_handler import ( - get_server_message_from_task_ins, - wrap_client_message_in_task_res, -) -from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde +from flwr.common.configsrecord import ConfigsRecord from flwr.common.constant import ( TASK_TYPE_EVALUATE, TASK_TYPE_FIT, @@ -50,12 +45,7 @@ recordset_to_getparametersins, recordset_to_getpropertiesins, ) -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, -) +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -81,120 +71,37 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: Returns ------- + task_res : Optional[TaskRes] + TaskRes to be returned to the server. If None, the client should + continue to process messages from the server. sleep_duration : int Number of seconds that the client should disconnect from the server. - keep_going : bool - Flag that indicates whether the client should continue to process the - next message from the server (True) or disconnect and optionally - reconnect later (False). """ - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - # SecAgg message - if server_msg is None: - return None, 0 - - # ReconnectIns message - field = server_msg.WhichOneof("msg") - if field == "reconnect_ins": - disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins) - task_res = wrap_client_message_in_task_res(disconnect_msg) + if task_ins.task.task_type == "reconnect": + # Retrieve ReconnectIns from recordset + recordset = serde.recordset_from_proto(task_ins.task.recordset) + seconds = cast(int, recordset.get_configs("config")["seconds"]) + # Construct ReconnectIns and call _reconnect + disconnect_msg, sleep_duration = _reconnect( + ServerMessage.ReconnectIns(seconds=seconds) + ) + # Store DisconnectRes in recordset + reason = cast(int, disconnect_msg.disconnect_res.reason) + recordset = RecordSet() + recordset.set_configs("config", ConfigsRecord({"reason": reason})) + task_res = TaskRes( + task=Task( + task_type="reconnect", + recordset=serde.recordset_to_proto(recordset), + ) + ) + # Return TaskRes and sleep duration return task_res, sleep_duration # Any other message return None, 0 -def handle( - client_fn: ClientFn, context: Context, task_ins: TaskIns -) -> Tuple[TaskRes, Context]: - """Handle incoming TaskIns from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - context : Context - A dataclass storing the context for the run being executed by the client. - task_ins: TaskIns - The task instruction coming from the server, to be processed by the client. - - Returns - ------- - task_res : TaskRes - The task response that should be returned to the server. - """ - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - if server_msg is None: - # Instantiate the client - client = client_fn("-1") - client.set_context(context) - # Secure Aggregation - if task_ins.task.HasField("sa") and isinstance( - client, SecureAggregationHandler - ): - # pylint: disable-next=invalid-name - named_values = serde.named_values_from_proto(task_ins.task.sa.named_values) - res = client.handle_secure_aggregation(named_values) - task_res = TaskRes( - task_id="", - group_id="", - run_id=0, - task=Task( - ancestry=[], - sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), - ), - ) - return task_res, client.get_context() - raise NotImplementedError() - client_msg, updated_context = handle_legacy_message(client_fn, context, server_msg) - task_res = wrap_client_message_in_task_res(client_msg) - return task_res, updated_context - - -def handle_legacy_message( - client_fn: ClientFn, context: Context, server_msg: ServerMessage -) -> Tuple[ClientMessage, Context]: - """Handle incoming messages from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - context : Context - A dataclass storing the context for the run being executed by the client. - server_msg: ServerMessage - The message coming from the server, to be processed by the client. - - Returns - ------- - client_msg : ClientMessage - The result message that should be returned to the server. - """ - field = server_msg.WhichOneof("msg") - - # Must be handled elsewhere - if field == "reconnect_ins": - raise UnexpectedServerMessage() - - # Instantiate the client - client = client_fn("-1") - client.set_context(context) - # Execute task - message = None - if field == "get_properties_ins": - message = _get_properties(client, server_msg.get_properties_ins) - if field == "get_parameters_ins": - message = _get_parameters(client, server_msg.get_parameters_ins) - if field == "fit_ins": - message = _fit(client, server_msg.fit_ins) - if field == "evaluate_ins": - message = _evaluate(client, server_msg.evaluate_ins) - if message: - return message, client.get_context() - raise UnknownServerMessage() - - def handle_legacy_message_from_tasktype( client_fn: ClientFn, message: Message, context: Context ) -> Message: @@ -253,67 +160,3 @@ def _reconnect( # Build DisconnectRes message disconnect_res = ClientMessage.DisconnectRes(reason=reason) return ClientMessage(disconnect_res=disconnect_res), sleep_duration - - -def _get_properties( - client: Client, get_properties_msg: ServerMessage.GetPropertiesIns -) -> ClientMessage: - # Deserialize `get_properties` instruction - get_properties_ins = serde.get_properties_ins_from_proto(get_properties_msg) - - # Request properties - get_properties_res = maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - # Serialize response - get_properties_res_proto = serde.get_properties_res_to_proto(get_properties_res) - return ClientMessage(get_properties_res=get_properties_res_proto) - - -def _get_parameters( - client: Client, get_parameters_msg: ServerMessage.GetParametersIns -) -> ClientMessage: - # Deserialize `get_parameters` instruction - get_parameters_ins = serde.get_parameters_ins_from_proto(get_parameters_msg) - - # Request parameters - get_parameters_res = maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - # Serialize response - get_parameters_res_proto = serde.get_parameters_res_to_proto(get_parameters_res) - return ClientMessage(get_parameters_res=get_parameters_res_proto) - - -def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage: - # Deserialize fit instruction - fit_ins = serde.fit_ins_from_proto(fit_msg) - - # Perform fit - fit_res = maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - # Serialize fit result - fit_res_proto = serde.fit_res_to_proto(fit_res) - return ClientMessage(fit_res=fit_res_proto) - - -def _evaluate(client: Client, evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage: - # Deserialize evaluate instruction - evaluate_ins = serde.evaluate_ins_from_proto(evaluate_msg) - - # Perform evaluation - evaluate_res = maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - # Serialize evaluate result - evaluate_res_proto = serde.evaluate_res_to_proto(evaluate_res) - return ClientMessage(evaluate_res=evaluate_res_proto) diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 707570cd8e57..842b6d1e1ee1 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -20,6 +20,7 @@ from flwr.client import Client from flwr.client.typing import ClientFn from flwr.common import ( + Code, EvaluateIns, EvaluateRes, FitIns, @@ -29,21 +30,16 @@ GetPropertiesIns, GetPropertiesRes, Parameters, - serde, - typing, + Status, ) +from flwr.common import recordset_compat as compat +from flwr.common import typing +from flwr.common.constant import TASK_TYPE_GET_PROPERTIES from flwr.common.context import Context +from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - Code, - ServerMessage, - Status, -) -from .message_handler import handle, handle_control_message +from .message_handler import handle_legacy_message_from_tasktype class ClientWithoutProps(Client): @@ -122,137 +118,71 @@ def test_client_without_get_properties() -> None: """Test client implementing get_properties.""" # Prepare client = ClientWithoutProps() - ins = ServerMessage.GetPropertiesIns() - - task_ins: TaskIns = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), + recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = Message( + metadata=Metadata( + run_id=0, + task_id=str(uuid.uuid4()), + group_id="", + ttl="", + task_type=TASK_TYPE_GET_PROPERTIES, ), + message=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_tasktype( client_fn=_get_client_fn(client), + message=message, context=Context(state=RecordSet()), - task_ins=task_ins, - ) - - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - ) - ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) ) - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member - # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( code=Code.GET_PROPERTIES_NOT_IMPLEMENTED, message="Client does not implement `get_properties`", - ) + ), + properties={}, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message(message.metadata, expected_rs) assert actual_msg == expected_msg - assert not disconnect_task_res - assert actual_sleep_duration == 0 def test_client_with_get_properties() -> None: """Test client not implementing get_properties.""" # Prepare client = ClientWithProps() - ins = ServerMessage.GetPropertiesIns() - task_ins = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), + recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = Message( + metadata=Metadata( + run_id=0, + task_id=str(uuid.uuid4()), + group_id="", + ttl="", + task_type=TASK_TYPE_GET_PROPERTIES, ), + message=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_tasktype( client_fn=_get_client_fn(client), + message=message, context=Context(state=RecordSet()), - task_ins=task_ins, ) - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - ) - ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) - ) - - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member - # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( code=Code.OK, message="Success", ), - properties=serde.properties_to_proto( - properties={"str_prop": "val", "int_prop": 1} - ), + properties={"str_prop": "val", "int_prop": 1}, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message(message.metadata, expected_rs) assert actual_msg == expected_msg - assert not disconnect_task_res - assert actual_sleep_duration == 0 From 981c81a8fdbbc900ab6b08761931fd61727a3edb Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sat, 27 Jan 2024 05:50:57 +0000 Subject: [PATCH 08/13] remove fwd/bwk --- src/py/flwr/client/typing.py | 19 ------------------- src/py/flwr/flower/__init__.py | 4 ---- 2 files changed, 23 deletions(-) diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index de8cabf83d65..18ddcb62f23f 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -14,32 +14,13 @@ # ============================================================================== """Custom types for Flower clients.""" -from dataclasses import dataclass from typing import Callable from flwr.common.context import Context from flwr.common.message import Message -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .client import Client as Client - -@dataclass -class Fwd: - """.""" - - task_ins: TaskIns - context: Context - - -@dataclass -class Bwd: - """.""" - - task_res: TaskRes - context: Context - - FlowerCallable = Callable[[Message, Context], Message] ClientFn = Callable[[str], Client] Layer = Callable[[Message, Context, FlowerCallable], Message] diff --git a/src/py/flwr/flower/__init__.py b/src/py/flwr/flower/__init__.py index 892a7ce5afdc..8b566981c77a 100644 --- a/src/py/flwr/flower/__init__.py +++ b/src/py/flwr/flower/__init__.py @@ -16,11 +16,7 @@ from flwr.client.flower import Flower as Flower -from flwr.client.typing import Bwd as Bwd -from flwr.client.typing import Fwd as Fwd __all__ = [ "Flower", - "Fwd", - "Bwd", ] From 1eadea684d442ec9e62edcf3c2828d66f90afc10 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Sat, 27 Jan 2024 09:54:08 +0100 Subject: [PATCH 09/13] Update src/py/flwr/client/middleware/utils_test.py --- src/py/flwr/client/middleware/utils_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 73abdbc61302..a4e1f6e87599 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -44,7 +44,7 @@ def make_mock_middleware(name: str, footprint: List[str]) -> Layer: def middleware(message: Message, context: Context, app: FlowerCallable) -> Message: footprint.append(name) - # add empty ConfigRegcord to in_message for this middleware layer + # add empty ConfigRecord to in_message for this middleware layer message.message.set_configs(name=name, record=ConfigsRecord()) _increment_context_counter(context) out_message: Message = app(message, context) From ecee3caad3f5fcb84d59264344ef56b5bf9a11df Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 27 Jan 2024 09:26:46 +0000 Subject: [PATCH 10/13] update taskins taskres validation --- .../client/grpc_rere_client/connection.py | 4 +- .../client/message_handler/task_handler.py | 51 +------- .../message_handler/task_handler_test.py | 116 ++---------------- src/py/flwr/client/rest_client/connection.py | 4 +- 4 files changed, 13 insertions(+), 162 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index cb1a7021dc9d..e5a4b6883378 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -148,9 +148,7 @@ def receive() -> Optional[TaskIns]: task_ins: Optional[TaskIns] = get_task_ins(response) # Discard the current TaskIns if not valid - if task_ins is not None and not validate_task_ins( - task_ins, discard_reconnect_ins=True - ): + if task_ins is not None and not validate_task_ins(task_ins): task_ins = None # Remember `task_ins` until `task_res` is available diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index 667cb9c98d46..daac1be77138 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -20,21 +20,15 @@ from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) -def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool: +def validate_task_ins(task_ins: TaskIns) -> bool: """Validate a TaskIns before it entering the message handling process. Parameters ---------- task_ins: TaskIns The task instruction coming from the server. - discard_reconnect_ins: bool - If True, ReconnectIns will not be considered as valid content. Returns ------- @@ -42,23 +36,8 @@ def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool: True if the TaskIns is deemed valid and therefore suitable for undergoing the message handling process, False otherwise. """ - # Check if the task_ins contains legacy_server_message or sa. - # If legacy_server_message is set, check if ServerMessage is one of - # {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns, ReconnectIns*} - # Discard ReconnectIns if discard_reconnect_ins is true. - if ( - not task_ins.HasField("task") - or ( - not task_ins.task.HasField("legacy_server_message") - and not task_ins.task.HasField("sa") - ) - or ( - discard_reconnect_ins - and task_ins.task.legacy_server_message.WhichOneof("msg") == "reconnect_ins" - ) - ): + if not (task_ins.HasField("task") and task_ins.task.HasField("recordset")): return False - return True @@ -110,32 +89,6 @@ def get_task_ins( return task_ins -def get_server_message_from_task_ins( - task_ins: TaskIns, exclude_reconnect_ins: bool -) -> Optional[ServerMessage]: - """Get ServerMessage from TaskIns, if available.""" - # Return the message if it is in - # {GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns} - # Return the message if it is ReconnectIns and exclude_reconnect_ins is False. - if not validate_task_ins( - task_ins, discard_reconnect_ins=exclude_reconnect_ins - ) or not task_ins.task.HasField("legacy_server_message"): - return None - - return task_ins.task.legacy_server_message - - -def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes: - """Wrap ClientMessage in TaskRes.""" - # Instantiate a TaskRes, only filling client_message field. - return TaskRes( - task_id="", - group_id="", - run_id=0, - task=Task(ancestry=[], legacy_client_message=client_message), - ) - - def configure_task_res( task_res: TaskRes, ref_task_ins: TaskIns, producer: Node ) -> TaskRes: diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index c1111d0935c0..9a668231d509 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -16,75 +16,35 @@ from flwr.client.message_handler.task_handler import ( - get_server_message_from_task_ins, get_task_ins, validate_task_ins, validate_task_res, - wrap_client_message_in_task_res, ) +from flwr.common import serde +from flwr.common.recordset import RecordSet from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, -) -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 def test_validate_task_ins_no_task() -> None: """Test validate_task_ins.""" task_ins = TaskIns(task=None) - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert not validate_task_ins(task_ins, discard_reconnect_ins=False) + assert not validate_task_ins(task_ins) def test_validate_task_ins_no_content() -> None: """Test validate_task_ins.""" - task_ins = TaskIns(task=Task(legacy_server_message=None, sa=None)) - - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert not validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_ins_with_reconnect_ins() -> None: - """Test validate_task_ins.""" - task_ins = TaskIns( - task=Task( - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns(seconds=3) - ) - ) - ) - - assert not validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) - - -def test_validate_task_ins_valid_legacy_server_message() -> None: - """Test validate_task_ins.""" - task_ins = TaskIns( - task=Task( - legacy_server_message=ServerMessage( - get_properties_ins=ServerMessage.GetPropertiesIns() - ) - ) - ) + task_ins = TaskIns(task=Task(recordset=None)) - assert validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) + assert not validate_task_ins(task_ins) -def test_validate_task_ins_valid_sa() -> None: +def test_validate_task_ins_valid() -> None: """Test validate_task_ins.""" - task_ins = TaskIns(task=Task(sa=SecureAggregation())) + task_ins = TaskIns(task=Task(recordset=serde.recordset_to_proto(RecordSet()))) - assert validate_task_ins(task_ins, discard_reconnect_ins=True) - assert validate_task_ins(task_ins, discard_reconnect_ins=False) + assert validate_task_ins(task_ins) def test_validate_task_res() -> None: @@ -142,61 +102,3 @@ def test_get_task_ins_multiple_ins() -> None: ) actual_task_ins = get_task_ins(res) assert actual_task_ins == expected_task_ins - - -def test_get_server_message_from_task_ins_invalid() -> None: - """Test get_server_message_from_task_ins.""" - task_ins = TaskIns(task=Task(legacy_server_message=None)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f is None - - -def test_get_server_message_from_task_ins_reconnect_ins() -> None: - """Test get_server_message_from_task_ins.""" - expected_server_message = ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns(seconds=3) - ) - task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f == expected_server_message - - -def test_get_server_message_from_task_ins_sa() -> None: - """Test get_server_message_from_task_ins.""" - task_ins = TaskIns(task=Task(sa=SecureAggregation())) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t is None - assert msg_f is None - - -def test_get_server_message_from_task_ins_valid_legacy_server_message() -> None: - """Test get_server_message_from_task_ins.""" - expected_server_message = ServerMessage( - get_properties_ins=ServerMessage.GetPropertiesIns() - ) - task_ins = TaskIns(task=Task(legacy_server_message=expected_server_message)) - msg_t = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=True) - msg_f = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - assert msg_t == expected_server_message - assert msg_f == expected_server_message - - -def test_wrap_client_message_in_task_res() -> None: - """Test wrap_client_message_in_task_res.""" - expected_client_message = ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes() - ) - task_res = wrap_client_message_in_task_res(expected_client_message) - - assert validate_task_res(task_res) - # pylint: disable-next=no-member - assert task_res.task.legacy_client_message == expected_client_message diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index bb55f130f1a8..1ab15f57e521 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -256,9 +256,7 @@ def receive() -> Optional[TaskIns]: task_ins: Optional[TaskIns] = get_task_ins(pull_task_ins_response_proto) # Discard the current TaskIns if not valid - if task_ins is not None and not validate_task_ins( - task_ins, discard_reconnect_ins=True - ): + if task_ins is not None and not validate_task_ins(task_ins): task_ins = None # Remember `task_ins` until `task_res` is available From ffef19e2adea07fece51d7581a3ecc1b7c4b17f9 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 27 Jan 2024 10:24:09 +0000 Subject: [PATCH 11/13] fix a bug in metadata of returning message --- .../client/message_handler/message_handler.py | 23 ++++++++++++++----- .../message_handler/message_handler_test.py | 6 +++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e16d190b0289..5a77e86c1afd 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -33,7 +33,7 @@ TASK_TYPE_GET_PROPERTIES, ) from flwr.common.context import Context -from flwr.common.message import Message +from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet from flwr.common.recordset_compat import ( evaluateres_to_recordset, @@ -112,21 +112,20 @@ def handle_legacy_message_from_tasktype( task_type = message.metadata.task_type - out_message = Message(metadata=message.metadata, message=RecordSet()) # Handle GetPropertiesIns if task_type == TASK_TYPE_GET_PROPERTIES: get_properties_res = maybe_call_get_properties( client=client, get_properties_ins=recordset_to_getpropertiesins(message.message), ) - out_message.message = getpropertiesres_to_recordset(get_properties_res) + out_recordset = getpropertiesres_to_recordset(get_properties_res) # Handle GetParametersIns elif task_type == TASK_TYPE_GET_PARAMETERS: get_parameters_res = maybe_call_get_parameters( client=client, get_parameters_ins=recordset_to_getparametersins(message.message), ) - out_message.message = getparametersres_to_recordset( + out_recordset = getparametersres_to_recordset( get_parameters_res, keep_input=False ) # Handle FitIns @@ -135,16 +134,28 @@ def handle_legacy_message_from_tasktype( client=client, fit_ins=recordset_to_fitins(message.message, keep_input=False), ) - out_message.message = fitres_to_recordset(fit_res, keep_input=False) + out_recordset = fitres_to_recordset(fit_res, keep_input=False) # Handle EvaluateIns elif task_type == TASK_TYPE_EVALUATE: evaluate_res = maybe_call_evaluate( client=client, evaluate_ins=recordset_to_evaluateins(message.message, keep_input=False), ) - out_message.message = evaluateres_to_recordset(evaluate_res) + out_recordset = evaluateres_to_recordset(evaluate_res) else: raise ValueError(f"Invalid task type: {task_type}") + + # Return Message + out_message = Message( + metadata=Metadata( + run_id=0, # Non-user defined + task_id="", # Non-user defined + group_id="", # Non-user defined + ttl="", + task_type=task_type, + ), + message=out_recordset, + ) return out_message diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 842b6d1e1ee1..ad09ca95abc7 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -148,7 +148,8 @@ def test_client_without_get_properties() -> None: expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) expected_msg = Message(message.metadata, expected_rs) - assert actual_msg == expected_msg + assert actual_msg.message == expected_msg.message + assert actual_msg.metadata.task_type == expected_msg.metadata.task_type def test_client_with_get_properties() -> None: @@ -185,4 +186,5 @@ def test_client_with_get_properties() -> None: expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) expected_msg = Message(message.metadata, expected_rs) - assert actual_msg == expected_msg + assert actual_msg.message == expected_msg.message + assert actual_msg.metadata.task_type == expected_msg.metadata.task_type From 5cff08eaee681771118273005b7ce55dbd898620 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 27 Jan 2024 10:44:22 +0000 Subject: [PATCH 12/13] keep input in handle function --- src/py/flwr/client/message_handler/message_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 5a77e86c1afd..b9abf9bf142b 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -126,20 +126,20 @@ def handle_legacy_message_from_tasktype( get_parameters_ins=recordset_to_getparametersins(message.message), ) out_recordset = getparametersres_to_recordset( - get_parameters_res, keep_input=False + get_parameters_res, keep_input=True ) # Handle FitIns elif task_type == TASK_TYPE_FIT: fit_res = maybe_call_fit( client=client, - fit_ins=recordset_to_fitins(message.message, keep_input=False), + fit_ins=recordset_to_fitins(message.message, keep_input=True), ) - out_recordset = fitres_to_recordset(fit_res, keep_input=False) + out_recordset = fitres_to_recordset(fit_res, keep_input=True) # Handle EvaluateIns elif task_type == TASK_TYPE_EVALUATE: evaluate_res = maybe_call_evaluate( client=client, - evaluate_ins=recordset_to_evaluateins(message.message, keep_input=False), + evaluate_ins=recordset_to_evaluateins(message.message, keep_input=True), ) out_recordset = evaluateres_to_recordset(evaluate_res) else: From 599a9744e5601c3260a6db482e442b85d8a00dd8 Mon Sep 17 00:00:00 2001 From: Heng Pan <134433891+panh99@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:40:55 +0000 Subject: [PATCH 13/13] Make `DriverClientProxy` work with `RecordSet` and `task_type`. (#2853) Co-authored-by: jafermarq Co-authored-by: Heng Pan Co-authored-by: Daniel J. Beutel --- .../client/message_handler/message_handler.py | 2 +- src/py/flwr/driver/driver_client_proxy.py | 89 +++++++++-------- .../flwr/driver/driver_client_proxy_test.py | 98 +++++++++++++------ src/py/flwr/server/state/state_test.py | 15 +-- src/py/flwr/server/utils/validator.py | 38 ++----- src/py/flwr/server/utils/validator_test.py | 84 ++-------------- 6 files changed, 136 insertions(+), 190 deletions(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index b9abf9bf142b..ea87c35c83f7 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -72,7 +72,7 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: Returns ------- task_res : Optional[TaskRes] - TaskRes to be returned to the server. If None, the client should + TaskRes to be sent back to the server. If None, the client should continue to process messages from the server. sleep_duration : int Number of seconds that the client should disconnect from the server. diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index 8b2e51c17ea0..e0ff26c035f7 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -16,16 +16,19 @@ import time -from typing import List, Optional, cast +from typing import List, Optional from flwr import common +from flwr.common import recordset_compat as compat from flwr.common import serde -from flwr.proto import ( # pylint: disable=E0611 - driver_pb2, - node_pb2, - task_pb2, - transport_pb2, +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, ) +from flwr.common.recordset import RecordSet +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy from .grpc_driver import GrpcDriver @@ -47,55 +50,51 @@ def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_properties_ins=ins) - ) - ) - return cast( - common.GetPropertiesRes, - self._send_receive_msg(server_message_proto, timeout).get_properties_res, + # Ins to RecordSet + out_recordset = compat.getpropertiesins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_GET_PROPERTIES, timeout ) + # RecordSet to Res + return compat.recordset_to_getpropertiesres(in_recordset) def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] ) -> common.GetParametersRes: """Return the current local model parameters.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_parameters_ins=ins) - ) - ) - return cast( - common.GetParametersRes, - self._send_receive_msg(server_message_proto, timeout).get_parameters_res, + # Ins to RecordSet + out_recordset = compat.getparametersins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_GET_PARAMETERS, timeout ) + # RecordSet to Res + return compat.recordset_to_getparametersres(in_recordset, False) def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(fit_ins=ins) - ) - ) - return cast( - common.FitRes, - self._send_receive_msg(server_message_proto, timeout).fit_res, + # Ins to RecordSet + out_recordset = compat.fitins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_FIT, timeout ) + # RecordSet to Res + return compat.recordset_to_fitres(in_recordset, keep_input=False) def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(evaluate_ins=ins) - ) - ) - return cast( - common.EvaluateRes, - self._send_receive_msg(server_message_proto, timeout).evaluate_res, + # Ins to RecordSet + out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_EVALUATE, timeout ) + # RecordSet to Res + return compat.recordset_to_evaluateres(in_recordset) def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float] @@ -103,11 +102,12 @@ def reconnect( """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - def _send_receive_msg( + def _send_receive_recordset( self, - server_message: transport_pb2.ServerMessage, # pylint: disable=E1101 + recordset: RecordSet, + task_type: str, timeout: Optional[float], - ) -> transport_pb2.ClientMessage: # pylint: disable=E1101 + ) -> RecordSet: task_ins = task_pb2.TaskIns( # pylint: disable=E1101 task_id="", group_id="", @@ -121,7 +121,8 @@ def _send_receive_msg( node_id=self.node_id, anonymous=self.anonymous, ), - legacy_server_message=server_message, + task_type=task_type, + recordset=serde.recordset_to_proto(recordset), ), ) push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 @@ -155,9 +156,7 @@ def _send_receive_msg( ) if len(task_res_list) == 1: task_res = task_res_list[0] - return serde.client_message_from_proto( # type: ignore - task_res.task.legacy_client_message - ) + return serde.recordset_from_proto(task_res.task.recordset) if timeout is not None and time.time() > start_time + timeout: raise RuntimeError("Timeout reached") diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index d3cab152e4db..4e9a02a6cbf9 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -16,23 +16,63 @@ import unittest +from typing import Union, cast from unittest.mock import MagicMock import numpy as np import flwr -from flwr.common.typing import Config, GetParametersIns -from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) +from flwr.common.typing import ( + Code, + Config, + EvaluateIns, + EvaluateRes, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesRes, Parameters, - Scalar, + Properties, + Status, ) +from flwr.driver.driver_client_proxy import DriverClientProxy +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") -CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")} +CLIENT_PROPERTIES = cast(Properties, {"tensor_type": "numpy.ndarray"}) +CLIENT_STATUS = Status(code=Code.OK, message="OK") + + +def _make_task( + res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] +) -> task_pb2.Task: # pylint: disable=E1101 + if isinstance(res, GetParametersRes): + task_type = TASK_TYPE_GET_PARAMETERS + recordset = compat.getparametersres_to_recordset(res, True) + elif isinstance(res, GetPropertiesRes): + task_type = TASK_TYPE_GET_PROPERTIES + recordset = compat.getpropertiesres_to_recordset(res) + elif isinstance(res, FitRes): + task_type = TASK_TYPE_FIT + recordset = compat.fitres_to_recordset(res, True) + elif isinstance(res, EvaluateRes): + task_type = TASK_TYPE_EVALUATE + recordset = compat.evaluateres_to_recordset(res) + else: + raise ValueError(f"Unsupported type: {type(res)}") + return task_pb2.Task( # pylint: disable=E1101 + task_type=task_type, + recordset=serde.recordset_to_proto(recordset), + ) class DriverClientProxyTestCase(unittest.TestCase): @@ -64,11 +104,9 @@ def test_get_properties(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes( - properties=CLIENT_PROPERTIES - ) + task=_make_task( + GetPropertiesRes( + status=CLIENT_STATUS, properties=CLIENT_PROPERTIES ) ), ) @@ -104,11 +142,10 @@ def test_get_parameters(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_parameters_res=ClientMessage.GetParametersRes( - parameters=MESSAGE_PARAMETERS, - ) + task=_make_task( + GetParametersRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, ) ), ) @@ -143,12 +180,12 @@ def test_fit(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - fit_res=ClientMessage.FitRes( - parameters=MESSAGE_PARAMETERS, - num_examples=10, - ) + task=_make_task( + FitRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, + num_examples=10, + metrics={}, ) ), ) @@ -184,11 +221,12 @@ def test_evaluate(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - evaluate_res=ClientMessage.EvaluateRes( - loss=0.0, num_examples=0 - ) + task=_make_task( + EvaluateRes( + status=CLIENT_STATUS, + loss=0.0, + num_examples=0, + metrics={}, ) ), ) @@ -198,8 +236,8 @@ def test_evaluate(self) -> None: client = DriverClientProxy( node_id=1, driver=self.driver, anonymous=True, run_id=0 ) - parameters = flwr.common.Parameters(tensors=[], tensor_type="np") - evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) + parameters = Parameters(tensors=[], tensor_type="np") + evaluate_ins = EvaluateIns(parameters, {}) # Execute evaluate_res = client.evaluate(evaluate_ins, timeout=None) diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 7f9094625765..95d764792ff3 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -23,11 +23,8 @@ from uuid import uuid4 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) from flwr.server.state import InMemoryState, SqliteState, State @@ -421,9 +418,8 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ), + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -444,9 +440,8 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ), + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index 01dbcf982cce..f9b271beafdc 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -64,21 +64,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_server_message": tasks_ins_res.task.HasField( - "legacy_server_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_server_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskIns` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_server_message" - ] and not tasks_ins_res.task.legacy_server_message.HasField("msg"): - validation_errors.append("`legacy_server_message` does not set field `msg`") + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) != 0: @@ -115,21 +104,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_client_message": tasks_ins_res.task.HasField( - "legacy_client_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_client_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskRes` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_client_message" - ] and not tasks_ins_res.task.legacy_client_message.HasField("msg"): - validation_errors.append("`legacy_client_message` does not set field `msg`") + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) == 0: diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index a93e4fb4d457..8e0849508020 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -19,16 +19,8 @@ from typing import List, Tuple from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, -) -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from .validator import validate_task_ins_or_res @@ -45,16 +37,12 @@ def test_task_ins(self) -> None: # Execute & Assert for consumer_node_id, anonymous in valid_ins: - msg = create_task_ins( - consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for consumer_node_id, anonymous in invalid_ins: - msg = create_task_ins( - consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors) @@ -78,61 +66,19 @@ def test_is_valid_task_res(self) -> None: # Execute & Assert for producer_node_id, anonymous, ancestry in valid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for producer_node_id, anonymous, ancestry in invalid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors, (producer_node_id, anonymous, ancestry)) - def test_task_ins_secure_aggregation(self) -> None: - """Test is_valid task_ins for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_ins = [(True, True), (False, True)] - invalid_ins = [(False, False)] - - # Execute & Assert - for has_legacy_server_message, has_sa in valid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertFalse(val_errors) - - for has_legacy_server_message, has_sa in invalid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - - def test_task_res_secure_aggregation(self) -> None: - """Test is_valid task_res for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_res = [(True, True), (False, True)] - invalid_res = [(False, False)] - - # Execute & Assert - for has_legacy_client_message, has_sa in valid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertFalse(val_errors) - - for has_legacy_client_message, has_sa in invalid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - def create_task_ins( consumer_node_id: int, anonymous: bool, - has_legacy_server_message: bool = False, - has_sa: bool = False, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -148,12 +94,8 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ) - if has_legacy_server_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -163,8 +105,6 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - has_legacy_client_message: bool = False, - has_sa: bool = False, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( @@ -175,12 +115,8 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ) - if has_legacy_client_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res