From 9c85be5a6c932eb7ec8166f1b2a58b5b386a92e5 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 15 Aug 2024 16:47:51 +0100 Subject: [PATCH] feat(framework) Add `ClientAppIo` servicer (#3976) --- src/py/flwr/client/process/__init__.py | 15 ++ .../client/process/clientappio_servicer.py | 145 ++++++++++++++++++ .../process/clientappio_servicer_test.py | 118 ++++++++++++++ 3 files changed, 278 insertions(+) create mode 100644 src/py/flwr/client/process/__init__.py create mode 100644 src/py/flwr/client/process/clientappio_servicer.py create mode 100644 src/py/flwr/client/process/clientappio_servicer_test.py diff --git a/src/py/flwr/client/process/__init__.py b/src/py/flwr/client/process/__init__.py new file mode 100644 index 000000000000..653cee434c12 --- /dev/null +++ b/src/py/flwr/client/process/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower AppIO service.""" diff --git a/src/py/flwr/client/process/clientappio_servicer.py b/src/py/flwr/client/process/clientappio_servicer.py new file mode 100644 index 000000000000..f614fadf8070 --- /dev/null +++ b/src/py/flwr/client/process/clientappio_servicer.py @@ -0,0 +1,145 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ClientAppIo API servicer.""" + + +from dataclasses import dataclass +from logging import DEBUG, ERROR +from typing import Optional + +import grpc + +from flwr.common import Context, Message, typing +from flwr.common.logger import log +from flwr.common.serde import ( + clientappstatus_to_proto, + context_from_proto, + context_to_proto, + message_from_proto, + message_to_proto, + run_to_proto, +) +from flwr.common.typing import Run + +# pylint: disable=E0611 +from flwr.proto import clientappio_pb2_grpc +from flwr.proto.clientappio_pb2 import ( # pylint: disable=E0401 + PullClientAppInputsRequest, + PullClientAppInputsResponse, + PushClientAppOutputsRequest, + PushClientAppOutputsResponse, +) + + +@dataclass +class ClientAppIoInputs: + """Specify the inputs to the ClientApp.""" + + message: Message + context: Context + run: Run + token: int + + +@dataclass +class ClientAppIoOutputs: + """Specify the outputs from the ClientApp.""" + + message: Message + context: Context + + +# pylint: disable=C0103,W0613,W0201 +class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer): + """ClientAppIo API servicer.""" + + def __init__(self) -> None: + self.clientapp_input: Optional[ClientAppIoInputs] = None + self.clientapp_output: Optional[ClientAppIoOutputs] = None + + def PullClientAppInputs( + self, request: PullClientAppInputsRequest, context: grpc.ServicerContext + ) -> PullClientAppInputsResponse: + """Pull Message, Context, and Run.""" + log(DEBUG, "ClientAppIo.PullClientAppInputs") + if self.clientapp_input is None: + raise ValueError( + "ClientAppIoInputs not set before calling `PullClientAppInputs`." + ) + if request.token != self.clientapp_input.token: + context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + "Mismatch between ClientApp and SuperNode token", + ) + return PullClientAppInputsResponse( + message=message_to_proto(self.clientapp_input.message), + context=context_to_proto(self.clientapp_input.context), + run=run_to_proto(self.clientapp_input.run), + ) + + def PushClientAppOutputs( + self, request: PushClientAppOutputsRequest, context: grpc.ServicerContext + ) -> PushClientAppOutputsResponse: + """Push Message and Context.""" + log(DEBUG, "ClientAppIo.PushClientAppOutputs") + if self.clientapp_output is None: + raise ValueError( + "ClientAppIoOutputs not set before calling `PushClientAppOutputs`." + ) + if self.clientapp_input is None: + raise ValueError( + "ClientAppIoInputs not set before calling `PushClientAppOutputs`." + ) + if request.token != self.clientapp_input.token: + context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + "Mismatch between ClientApp and SuperNode token", + ) + try: + # Update Message and Context + self.clientapp_output.message = message_from_proto(request.message) + self.clientapp_output.context = context_from_proto(request.context) + # Set status + code = typing.ClientAppOutputCode.SUCCESS + status = typing.ClientAppOutputStatus(code=code, message="Success") + proto_status = clientappstatus_to_proto(status=status) + return PushClientAppOutputsResponse(status=proto_status) + except Exception as e: # pylint: disable=broad-exception-caught + log(ERROR, "ClientApp failed to push message to SuperNode, %s", e) + code = typing.ClientAppOutputCode.UNKNOWN_ERROR + status = typing.ClientAppOutputStatus(code=code, message="Push failed") + proto_status = clientappstatus_to_proto(status=status) + return PushClientAppOutputsResponse(status=proto_status) + + def set_inputs(self, clientapp_input: ClientAppIoInputs) -> None: + """Set ClientApp inputs.""" + log(DEBUG, "ClientAppIo.SetInputs") + if self.clientapp_input is not None or self.clientapp_output is not None: + raise ValueError( + "ClientAppIoInputs and ClientAppIoOutputs must not be set before " + "calling `set_inputs`." + ) + self.clientapp_input = clientapp_input + + def get_outputs(self) -> ClientAppIoOutputs: + """Get ClientApp outputs.""" + log(DEBUG, "ClientAppIo.GetOutputs") + if self.clientapp_output is None: + raise ValueError("ClientAppIoOutputs not set before calling `get_outputs`.") + # Set outputs to a local variable and clear self.clientapp_output + output: ClientAppIoOutputs = self.clientapp_output + self.clientapp_input = None + self.clientapp_output = None + return output diff --git a/src/py/flwr/client/process/clientappio_servicer_test.py b/src/py/flwr/client/process/clientappio_servicer_test.py new file mode 100644 index 000000000000..b06e9eb0e1c0 --- /dev/null +++ b/src/py/flwr/client/process/clientappio_servicer_test.py @@ -0,0 +1,118 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test the ClientAppIo API servicer.""" + +import unittest + +from flwr.common import Context, Message, typing +from flwr.common.serde_test import RecordMaker + +from .clientappio_servicer import ( + ClientAppIoInputs, + ClientAppIoOutputs, + ClientAppIoServicer, +) + + +class TestClientAppIoServicer(unittest.TestCase): + """Tests for `ClientAppIoServicer` class.""" + + def setUp(self) -> None: + """Initialize.""" + self.servicer = ClientAppIoServicer() + self.maker = RecordMaker() + + def tearDown(self) -> None: + """Cleanup.""" + + def test_set_inputs(self) -> None: + """Test setting ClientApp inputs.""" + # Prepare + message = Message( + metadata=self.maker.metadata(), + content=self.maker.recordset(2, 2, 1), + ) + context = Context( + node_id=1, + node_config={"nodeconfig1": 4.2}, + state=self.maker.recordset(2, 2, 1), + run_config={"runconfig1": 6.1}, + ) + run = typing.Run( + run_id=1, + fab_id="lorem", + fab_version="ipsum", + fab_hash="dolor", + override_config=self.maker.user_config(), + ) + client_input = ClientAppIoInputs(message, context, run, 1) + client_output = ClientAppIoOutputs(message, context) + + # Execute and assert + # - when ClientAppIoInputs is not None, ClientAppIoOutputs is None + with self.assertRaises(ValueError): + self.servicer.clientapp_input = client_input + self.servicer.clientapp_output = None + self.servicer.set_inputs(client_input) + + # Execute and assert + # - when ClientAppIoInputs is None, ClientAppIoOutputs is not None + with self.assertRaises(ValueError): + self.servicer.clientapp_input = None + self.servicer.clientapp_output = client_output + self.servicer.set_inputs(client_input) + + # Execute and assert + # - when ClientAppIoInputs and ClientAppIoOutputs is not None + with self.assertRaises(ValueError): + self.servicer.clientapp_input = client_input + self.servicer.clientapp_output = client_output + self.servicer.set_inputs(client_input) + + # Execute and assert + # - when ClientAppIoInputs is set at .clientapp_input + self.servicer.clientapp_input = None + self.servicer.clientapp_output = None + self.servicer.set_inputs(client_input) + assert client_input == self.servicer.clientapp_input + + def test_get_outputs(self) -> None: + """Test getting ClientApp outputs.""" + # Prepare + message = Message( + metadata=self.maker.metadata(), + content=self.maker.recordset(2, 2, 1), + ) + context = Context( + node_id=1, + node_config={"nodeconfig1": 4.2}, + state=self.maker.recordset(2, 2, 1), + run_config={"runconfig1": 6.1}, + ) + client_output = ClientAppIoOutputs(message, context) + + # Execute and assert - when `ClientAppIoOutputs` is None + self.servicer.clientapp_output = None + with self.assertRaises(ValueError): + # `ClientAppIoOutputs` should not be None + _ = self.servicer.get_outputs() + + # Execute and assert - when `ClientAppIoOutputs` is not None + self.servicer.clientapp_output = client_output + output = self.servicer.get_outputs() + assert isinstance(output, ClientAppIoOutputs) + assert output == client_output + assert self.servicer.clientapp_input is None + assert self.servicer.clientapp_output is None