Skip to content

Commit

Permalink
feat(framework) Add ClientAppIo servicer (#3976)
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng authored Aug 15, 2024
1 parent 480f683 commit 9c85be5
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/py/flwr/client/process/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower AppIO service."""
145 changes: 145 additions & 0 deletions src/py/flwr/client/process/clientappio_servicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ClientAppIo API servicer."""


from dataclasses import dataclass
from logging import DEBUG, ERROR
from typing import Optional

import grpc

from flwr.common import Context, Message, typing
from flwr.common.logger import log
from flwr.common.serde import (
clientappstatus_to_proto,
context_from_proto,
context_to_proto,
message_from_proto,
message_to_proto,
run_to_proto,
)
from flwr.common.typing import Run

# pylint: disable=E0611
from flwr.proto import clientappio_pb2_grpc
from flwr.proto.clientappio_pb2 import ( # pylint: disable=E0401
PullClientAppInputsRequest,
PullClientAppInputsResponse,
PushClientAppOutputsRequest,
PushClientAppOutputsResponse,
)


@dataclass
class ClientAppIoInputs:
"""Specify the inputs to the ClientApp."""

message: Message
context: Context
run: Run
token: int


@dataclass
class ClientAppIoOutputs:
"""Specify the outputs from the ClientApp."""

message: Message
context: Context


# pylint: disable=C0103,W0613,W0201
class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
"""ClientAppIo API servicer."""

def __init__(self) -> None:
self.clientapp_input: Optional[ClientAppIoInputs] = None
self.clientapp_output: Optional[ClientAppIoOutputs] = None

def PullClientAppInputs(
self, request: PullClientAppInputsRequest, context: grpc.ServicerContext
) -> PullClientAppInputsResponse:
"""Pull Message, Context, and Run."""
log(DEBUG, "ClientAppIo.PullClientAppInputs")
if self.clientapp_input is None:
raise ValueError(
"ClientAppIoInputs not set before calling `PullClientAppInputs`."
)
if request.token != self.clientapp_input.token:
context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"Mismatch between ClientApp and SuperNode token",
)
return PullClientAppInputsResponse(
message=message_to_proto(self.clientapp_input.message),
context=context_to_proto(self.clientapp_input.context),
run=run_to_proto(self.clientapp_input.run),
)

def PushClientAppOutputs(
self, request: PushClientAppOutputsRequest, context: grpc.ServicerContext
) -> PushClientAppOutputsResponse:
"""Push Message and Context."""
log(DEBUG, "ClientAppIo.PushClientAppOutputs")
if self.clientapp_output is None:
raise ValueError(
"ClientAppIoOutputs not set before calling `PushClientAppOutputs`."
)
if self.clientapp_input is None:
raise ValueError(
"ClientAppIoInputs not set before calling `PushClientAppOutputs`."
)
if request.token != self.clientapp_input.token:
context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
"Mismatch between ClientApp and SuperNode token",
)
try:
# Update Message and Context
self.clientapp_output.message = message_from_proto(request.message)
self.clientapp_output.context = context_from_proto(request.context)
# Set status
code = typing.ClientAppOutputCode.SUCCESS
status = typing.ClientAppOutputStatus(code=code, message="Success")
proto_status = clientappstatus_to_proto(status=status)
return PushClientAppOutputsResponse(status=proto_status)
except Exception as e: # pylint: disable=broad-exception-caught
log(ERROR, "ClientApp failed to push message to SuperNode, %s", e)
code = typing.ClientAppOutputCode.UNKNOWN_ERROR
status = typing.ClientAppOutputStatus(code=code, message="Push failed")
proto_status = clientappstatus_to_proto(status=status)
return PushClientAppOutputsResponse(status=proto_status)

def set_inputs(self, clientapp_input: ClientAppIoInputs) -> None:
"""Set ClientApp inputs."""
log(DEBUG, "ClientAppIo.SetInputs")
if self.clientapp_input is not None or self.clientapp_output is not None:
raise ValueError(
"ClientAppIoInputs and ClientAppIoOutputs must not be set before "
"calling `set_inputs`."
)
self.clientapp_input = clientapp_input

def get_outputs(self) -> ClientAppIoOutputs:
"""Get ClientApp outputs."""
log(DEBUG, "ClientAppIo.GetOutputs")
if self.clientapp_output is None:
raise ValueError("ClientAppIoOutputs not set before calling `get_outputs`.")
# Set outputs to a local variable and clear self.clientapp_output
output: ClientAppIoOutputs = self.clientapp_output
self.clientapp_input = None
self.clientapp_output = None
return output
118 changes: 118 additions & 0 deletions src/py/flwr/client/process/clientappio_servicer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the ClientAppIo API servicer."""

import unittest

from flwr.common import Context, Message, typing
from flwr.common.serde_test import RecordMaker

from .clientappio_servicer import (
ClientAppIoInputs,
ClientAppIoOutputs,
ClientAppIoServicer,
)


class TestClientAppIoServicer(unittest.TestCase):
"""Tests for `ClientAppIoServicer` class."""

def setUp(self) -> None:
"""Initialize."""
self.servicer = ClientAppIoServicer()
self.maker = RecordMaker()

def tearDown(self) -> None:
"""Cleanup."""

def test_set_inputs(self) -> None:
"""Test setting ClientApp inputs."""
# Prepare
message = Message(
metadata=self.maker.metadata(),
content=self.maker.recordset(2, 2, 1),
)
context = Context(
node_id=1,
node_config={"nodeconfig1": 4.2},
state=self.maker.recordset(2, 2, 1),
run_config={"runconfig1": 6.1},
)
run = typing.Run(
run_id=1,
fab_id="lorem",
fab_version="ipsum",
fab_hash="dolor",
override_config=self.maker.user_config(),
)
client_input = ClientAppIoInputs(message, context, run, 1)
client_output = ClientAppIoOutputs(message, context)

# Execute and assert
# - when ClientAppIoInputs is not None, ClientAppIoOutputs is None
with self.assertRaises(ValueError):
self.servicer.clientapp_input = client_input
self.servicer.clientapp_output = None
self.servicer.set_inputs(client_input)

# Execute and assert
# - when ClientAppIoInputs is None, ClientAppIoOutputs is not None
with self.assertRaises(ValueError):
self.servicer.clientapp_input = None
self.servicer.clientapp_output = client_output
self.servicer.set_inputs(client_input)

# Execute and assert
# - when ClientAppIoInputs and ClientAppIoOutputs is not None
with self.assertRaises(ValueError):
self.servicer.clientapp_input = client_input
self.servicer.clientapp_output = client_output
self.servicer.set_inputs(client_input)

# Execute and assert
# - when ClientAppIoInputs is set at .clientapp_input
self.servicer.clientapp_input = None
self.servicer.clientapp_output = None
self.servicer.set_inputs(client_input)
assert client_input == self.servicer.clientapp_input

def test_get_outputs(self) -> None:
"""Test getting ClientApp outputs."""
# Prepare
message = Message(
metadata=self.maker.metadata(),
content=self.maker.recordset(2, 2, 1),
)
context = Context(
node_id=1,
node_config={"nodeconfig1": 4.2},
state=self.maker.recordset(2, 2, 1),
run_config={"runconfig1": 6.1},
)
client_output = ClientAppIoOutputs(message, context)

# Execute and assert - when `ClientAppIoOutputs` is None
self.servicer.clientapp_output = None
with self.assertRaises(ValueError):
# `ClientAppIoOutputs` should not be None
_ = self.servicer.get_outputs()

# Execute and assert - when `ClientAppIoOutputs` is not None
self.servicer.clientapp_output = client_output
output = self.servicer.get_outputs()
assert isinstance(output, ClientAppIoOutputs)
assert output == client_output
assert self.servicer.clientapp_input is None
assert self.servicer.clientapp_output is None

0 comments on commit 9c85be5

Please sign in to comment.