-
Notifications
You must be signed in to change notification settings - Fork 941
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(framework) Add
ClientAppIo
servicer (#3976)
- Loading branch information
1 parent
480f683
commit 9c85be5
Showing
3 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Flower AppIO service.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""ClientAppIo API servicer.""" | ||
|
||
|
||
from dataclasses import dataclass | ||
from logging import DEBUG, ERROR | ||
from typing import Optional | ||
|
||
import grpc | ||
|
||
from flwr.common import Context, Message, typing | ||
from flwr.common.logger import log | ||
from flwr.common.serde import ( | ||
clientappstatus_to_proto, | ||
context_from_proto, | ||
context_to_proto, | ||
message_from_proto, | ||
message_to_proto, | ||
run_to_proto, | ||
) | ||
from flwr.common.typing import Run | ||
|
||
# pylint: disable=E0611 | ||
from flwr.proto import clientappio_pb2_grpc | ||
from flwr.proto.clientappio_pb2 import ( # pylint: disable=E0401 | ||
PullClientAppInputsRequest, | ||
PullClientAppInputsResponse, | ||
PushClientAppOutputsRequest, | ||
PushClientAppOutputsResponse, | ||
) | ||
|
||
|
||
@dataclass | ||
class ClientAppIoInputs: | ||
"""Specify the inputs to the ClientApp.""" | ||
|
||
message: Message | ||
context: Context | ||
run: Run | ||
token: int | ||
|
||
|
||
@dataclass | ||
class ClientAppIoOutputs: | ||
"""Specify the outputs from the ClientApp.""" | ||
|
||
message: Message | ||
context: Context | ||
|
||
|
||
# pylint: disable=C0103,W0613,W0201 | ||
class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer): | ||
"""ClientAppIo API servicer.""" | ||
|
||
def __init__(self) -> None: | ||
self.clientapp_input: Optional[ClientAppIoInputs] = None | ||
self.clientapp_output: Optional[ClientAppIoOutputs] = None | ||
|
||
def PullClientAppInputs( | ||
self, request: PullClientAppInputsRequest, context: grpc.ServicerContext | ||
) -> PullClientAppInputsResponse: | ||
"""Pull Message, Context, and Run.""" | ||
log(DEBUG, "ClientAppIo.PullClientAppInputs") | ||
if self.clientapp_input is None: | ||
raise ValueError( | ||
"ClientAppIoInputs not set before calling `PullClientAppInputs`." | ||
) | ||
if request.token != self.clientapp_input.token: | ||
context.abort( | ||
grpc.StatusCode.INVALID_ARGUMENT, | ||
"Mismatch between ClientApp and SuperNode token", | ||
) | ||
return PullClientAppInputsResponse( | ||
message=message_to_proto(self.clientapp_input.message), | ||
context=context_to_proto(self.clientapp_input.context), | ||
run=run_to_proto(self.clientapp_input.run), | ||
) | ||
|
||
def PushClientAppOutputs( | ||
self, request: PushClientAppOutputsRequest, context: grpc.ServicerContext | ||
) -> PushClientAppOutputsResponse: | ||
"""Push Message and Context.""" | ||
log(DEBUG, "ClientAppIo.PushClientAppOutputs") | ||
if self.clientapp_output is None: | ||
raise ValueError( | ||
"ClientAppIoOutputs not set before calling `PushClientAppOutputs`." | ||
) | ||
if self.clientapp_input is None: | ||
raise ValueError( | ||
"ClientAppIoInputs not set before calling `PushClientAppOutputs`." | ||
) | ||
if request.token != self.clientapp_input.token: | ||
context.abort( | ||
grpc.StatusCode.INVALID_ARGUMENT, | ||
"Mismatch between ClientApp and SuperNode token", | ||
) | ||
try: | ||
# Update Message and Context | ||
self.clientapp_output.message = message_from_proto(request.message) | ||
self.clientapp_output.context = context_from_proto(request.context) | ||
# Set status | ||
code = typing.ClientAppOutputCode.SUCCESS | ||
status = typing.ClientAppOutputStatus(code=code, message="Success") | ||
proto_status = clientappstatus_to_proto(status=status) | ||
return PushClientAppOutputsResponse(status=proto_status) | ||
except Exception as e: # pylint: disable=broad-exception-caught | ||
log(ERROR, "ClientApp failed to push message to SuperNode, %s", e) | ||
code = typing.ClientAppOutputCode.UNKNOWN_ERROR | ||
status = typing.ClientAppOutputStatus(code=code, message="Push failed") | ||
proto_status = clientappstatus_to_proto(status=status) | ||
return PushClientAppOutputsResponse(status=proto_status) | ||
|
||
def set_inputs(self, clientapp_input: ClientAppIoInputs) -> None: | ||
"""Set ClientApp inputs.""" | ||
log(DEBUG, "ClientAppIo.SetInputs") | ||
if self.clientapp_input is not None or self.clientapp_output is not None: | ||
raise ValueError( | ||
"ClientAppIoInputs and ClientAppIoOutputs must not be set before " | ||
"calling `set_inputs`." | ||
) | ||
self.clientapp_input = clientapp_input | ||
|
||
def get_outputs(self) -> ClientAppIoOutputs: | ||
"""Get ClientApp outputs.""" | ||
log(DEBUG, "ClientAppIo.GetOutputs") | ||
if self.clientapp_output is None: | ||
raise ValueError("ClientAppIoOutputs not set before calling `get_outputs`.") | ||
# Set outputs to a local variable and clear self.clientapp_output | ||
output: ClientAppIoOutputs = self.clientapp_output | ||
self.clientapp_input = None | ||
self.clientapp_output = None | ||
return output |
118 changes: 118 additions & 0 deletions
118
src/py/flwr/client/process/clientappio_servicer_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Test the ClientAppIo API servicer.""" | ||
|
||
import unittest | ||
|
||
from flwr.common import Context, Message, typing | ||
from flwr.common.serde_test import RecordMaker | ||
|
||
from .clientappio_servicer import ( | ||
ClientAppIoInputs, | ||
ClientAppIoOutputs, | ||
ClientAppIoServicer, | ||
) | ||
|
||
|
||
class TestClientAppIoServicer(unittest.TestCase): | ||
"""Tests for `ClientAppIoServicer` class.""" | ||
|
||
def setUp(self) -> None: | ||
"""Initialize.""" | ||
self.servicer = ClientAppIoServicer() | ||
self.maker = RecordMaker() | ||
|
||
def tearDown(self) -> None: | ||
"""Cleanup.""" | ||
|
||
def test_set_inputs(self) -> None: | ||
"""Test setting ClientApp inputs.""" | ||
# Prepare | ||
message = Message( | ||
metadata=self.maker.metadata(), | ||
content=self.maker.recordset(2, 2, 1), | ||
) | ||
context = Context( | ||
node_id=1, | ||
node_config={"nodeconfig1": 4.2}, | ||
state=self.maker.recordset(2, 2, 1), | ||
run_config={"runconfig1": 6.1}, | ||
) | ||
run = typing.Run( | ||
run_id=1, | ||
fab_id="lorem", | ||
fab_version="ipsum", | ||
fab_hash="dolor", | ||
override_config=self.maker.user_config(), | ||
) | ||
client_input = ClientAppIoInputs(message, context, run, 1) | ||
client_output = ClientAppIoOutputs(message, context) | ||
|
||
# Execute and assert | ||
# - when ClientAppIoInputs is not None, ClientAppIoOutputs is None | ||
with self.assertRaises(ValueError): | ||
self.servicer.clientapp_input = client_input | ||
self.servicer.clientapp_output = None | ||
self.servicer.set_inputs(client_input) | ||
|
||
# Execute and assert | ||
# - when ClientAppIoInputs is None, ClientAppIoOutputs is not None | ||
with self.assertRaises(ValueError): | ||
self.servicer.clientapp_input = None | ||
self.servicer.clientapp_output = client_output | ||
self.servicer.set_inputs(client_input) | ||
|
||
# Execute and assert | ||
# - when ClientAppIoInputs and ClientAppIoOutputs is not None | ||
with self.assertRaises(ValueError): | ||
self.servicer.clientapp_input = client_input | ||
self.servicer.clientapp_output = client_output | ||
self.servicer.set_inputs(client_input) | ||
|
||
# Execute and assert | ||
# - when ClientAppIoInputs is set at .clientapp_input | ||
self.servicer.clientapp_input = None | ||
self.servicer.clientapp_output = None | ||
self.servicer.set_inputs(client_input) | ||
assert client_input == self.servicer.clientapp_input | ||
|
||
def test_get_outputs(self) -> None: | ||
"""Test getting ClientApp outputs.""" | ||
# Prepare | ||
message = Message( | ||
metadata=self.maker.metadata(), | ||
content=self.maker.recordset(2, 2, 1), | ||
) | ||
context = Context( | ||
node_id=1, | ||
node_config={"nodeconfig1": 4.2}, | ||
state=self.maker.recordset(2, 2, 1), | ||
run_config={"runconfig1": 6.1}, | ||
) | ||
client_output = ClientAppIoOutputs(message, context) | ||
|
||
# Execute and assert - when `ClientAppIoOutputs` is None | ||
self.servicer.clientapp_output = None | ||
with self.assertRaises(ValueError): | ||
# `ClientAppIoOutputs` should not be None | ||
_ = self.servicer.get_outputs() | ||
|
||
# Execute and assert - when `ClientAppIoOutputs` is not None | ||
self.servicer.clientapp_output = client_output | ||
output = self.servicer.get_outputs() | ||
assert isinstance(output, ClientAppIoOutputs) | ||
assert output == client_output | ||
assert self.servicer.clientapp_input is None | ||
assert self.servicer.clientapp_output is None |