Skip to content

Commit

Permalink
v0
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 23, 2024
1 parent 214d1c8 commit 81db36f
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 38 deletions.
40 changes: 29 additions & 11 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flwr.client.client import Client
from flwr.client.flower import Flower
from flwr.client.typing import Bwd, ClientFn, Fwd
from flwr.client.typing import ClientFn
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.constant import (
Expand All @@ -34,8 +34,11 @@
TRANSPORT_TYPE_REST,
TRANSPORT_TYPES,
)
from flwr.common.flowercontext import FlowerContext, Metadata
from flwr.common.logger import log, warn_experimental_feature
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.common.recordset import RecordSet
from flwr.common.serde import recordset_to_proto
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611

from .flower import load_flower_callable
from .grpc_client.connection import grpc_connection
Expand Down Expand Up @@ -323,6 +326,15 @@ def _load_app() -> Flower:
connection, address = _init_connection(transport, server_address)

node_state = NodeState()
# TODO: remove NodeState/RunState logic ?

# TODO: initialize context here?
context = FlowerContext(
in_message=RecordSet(),
out_message=RecordSet(),
local=RecordSet(),
metadata=Metadata(run_id=-1, task_id="", group_id="", ttl="", task_type=""),
)

while True:
sleep_duration: int = 0
Expand Down Expand Up @@ -354,24 +366,30 @@ def _load_app() -> Flower:
# Register state
node_state.register_runstate(run_id=task_ins.run_id)

# TODO: pulate context.metadata and context.in_message from TaskIns

# Load app
app: Flower = load_flower_callable_fn()

# Handle task message
fwd_msg: Fwd = Fwd(
task_ins=task_ins,
state=node_state.retrieve_runstate(run_id=task_ins.run_id),
)
bwd_msg: Bwd = app(fwd=fwd_msg)
context_ = app(context=context)

# Update node state
node_state.update_runstate(
run_id=bwd_msg.task_res.run_id,
run_state=bwd_msg.state,
# node_state.update_runstate(
# run_id=bwd_msg.task_res.run_id,
# run_state=bwd_msg.state,
# )

# TODO: Construct TaskRes from context.out_message
task_res = TaskRes(
task_id=context_.metadata.task_id,
group_id=context_.metadata.group_id,
run_id=context_.metadata.run_id,
task=Task(recordset=recordset_to_proto(context_.out_message)),
)

# Send
send(bwd_msg.task_res)
send(task_res)

# Unregister node
if delete_node is not None:
Expand Down
23 changes: 13 additions & 10 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import importlib
from typing import List, Optional, cast

from flwr.client.message_handler.message_handler import handle
from flwr.client.message_handler.message_handler import (
handle_legacy_message_from_tasktype,
)
from flwr.client.middleware.utils import make_ffn
from flwr.client.typing import Bwd, ClientFn, Fwd, Layer
from flwr.client.typing import ClientFn, Layer
from flwr.common.flowercontext import FlowerContext


class Flower:
Expand Down Expand Up @@ -55,20 +58,20 @@ def __init__(
layers: Optional[List[Layer]] = None,
) -> None:
# Create wrapper function for `handle`
def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name
task_res, state_updated = handle(
client_fn=client_fn,
state=fwd.state,
task_ins=fwd.task_ins,
def ffn(
context: FlowerContext,
) -> FlowerContext: # pylint: disable=invalid-name
context = handle_legacy_message_from_tasktype(
client_fn=client_fn, context=context
)
return Bwd(task_res=task_res, state=state_updated)
return context

# Wrap middleware layers around the wrapped handle function
self._call = make_ffn(ffn, layers if layers is not None else [])

def __call__(self, fwd: Fwd) -> Bwd:
def __call__(self, context: FlowerContext) -> FlowerContext:
"""."""
return self._call(fwd)
return self._call(context)


class LoadCallableError(Exception):
Expand Down
49 changes: 49 additions & 0 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.common.flowercontext import FlowerContext
from flwr.common.recordset_compat import (
evaluateres_to_recordset,
fitres_to_recordset,
getparametersres_to_recordset,
getpropertiesres_to_recordset,
recordset_to_evaluateins,
recordset_to_fitins,
recordset_to_getparametersins,
recordset_to_getpropertiesins,
)
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
SecureAggregation,
Task,
Expand Down Expand Up @@ -177,6 +188,44 @@ def handle_legacy_message(
raise UnknownServerMessage()


def handle_legacy_message_from_tasktype(
client_fn: ClientFn, context: FlowerContext
) -> FlowerContext:
"""Handle legacy message in the inner most middleware layer."""
client = client_fn("-1")
task_type = context.metadata.task_type

if task_type == "get_properties_ins":
get_properties_res = maybe_call_get_properties(
client=client,
get_properties_ins=recordset_to_getpropertiesins(context.in_message),
)
context.out_message = getpropertiesres_to_recordset(get_properties_res)
elif task_type == "get_parameteres_ins":
get_parameters_res = maybe_call_get_parameters(
client=client,
get_parameters_ins=recordset_to_getparametersins(context.in_message),
)
context.out_message = getparametersres_to_recordset(get_parameters_res)
elif task_type == "fit_ins":
fit_res = maybe_call_fit(
client=client,
fit_ins=recordset_to_fitins(context.in_message, keep_input=False),
)
context.out_message = fitres_to_recordset(fit_res, keep_input=False)
elif task_type == "evaluate_ins":
evaluate_res = maybe_call_evaluate(
client=client,
evaluate_ins=recordset_to_evaluateins(context.in_message, keep_input=False),
)
context.out_message = evaluateres_to_recordset(evaluate_res)
else:
# TODO: what to do with reconnect?
print("do something")

return context


def _reconnect(
reconnect_msg: ServerMessage.ReconnectIns,
) -> Tuple[ClientMessage, int]:
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/client/middleware/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@

from typing import List

from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.client.typing import FlowerCallable, Layer
from flwr.common.flowercontext import FlowerContext


def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable:
"""."""

def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable:
def new_ffn(fwd: Fwd) -> Bwd:
return _layer(fwd, _ffn)
def new_ffn(context: FlowerContext) -> FlowerContext:
return _layer(context, _ffn)

return new_ffn

Expand Down
38 changes: 26 additions & 12 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

from flwr.client.run_state import RunState
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.flowercontext import FlowerContext, Metadata
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .utils import make_ffn
Expand All @@ -28,24 +31,25 @@
def make_mock_middleware(name: str, footprint: List[str]) -> Layer:
"""Make a mock middleware layer."""

def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd:
def middleware(context: FlowerContext, app: FlowerCallable) -> FlowerContext:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
bwd = app(fwd)
context.in_message.set_configs(name=name, record=ConfigsRecord())
ctx: FlowerContext = app(context)
footprint.append(name)
bwd.task_res.task_id += f"{name}"
return bwd
ctx.out_message.set_configs(name=name, record=ConfigsRecord())
return ctx

return middleware


def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable:
"""Make a mock app."""

def app(fwd: Fwd) -> Bwd:
def app(context: FlowerContext) -> FlowerContext:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
return Bwd(task_res=TaskRes(task_id=name), state=RunState({}))
context.in_message.set_configs(name=name, record=ConfigsRecord())
context.out_message.set_configs(name=name, record=ConfigsRecord())
return context

return app

Expand All @@ -62,18 +66,28 @@ def test_multiple_middlewares(self) -> None:
mock_middleware_layers = [
make_mock_middleware(name, footprint) for name in mock_middleware_names
]
task_ins = TaskIns()

context = FlowerContext(
in_message=RecordSet(),
out_message=RecordSet(),
local=RecordSet(),
metadata=Metadata(
run_id=0, task_id="", group_id="", ttl="", task_type="mock"
),
)

# Execute
wrapped_app = make_ffn(mock_app, mock_middleware_layers)
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res
context_ = wrapped_app(context)

# Assert
trace = mock_middleware_names + ["app"]
self.assertEqual(footprint, trace + list(reversed(mock_middleware_names)))
# pylint: disable-next=no-member
self.assertEqual(task_ins.task_id, "".join(trace))
self.assertEqual(task_res.task_id, "".join(reversed(trace)))
self.assertEqual("".join(context_.in_message.configs.keys()), "".join(trace))
self.assertEqual(
"".join(context_.out_message.configs.keys()), "".join(reversed(trace))
)

def test_filter(self) -> None:
"""Test if a middleware can filter incoming TaskIns."""
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Callable

from flwr.client.run_state import RunState
from flwr.common.flowercontext import FlowerContext
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .client import Client as Client
Expand All @@ -39,6 +40,6 @@ class Bwd:
state: RunState


FlowerCallable = Callable[[Fwd], Bwd]
FlowerCallable = Callable[[FlowerContext], FlowerContext]
ClientFn = Callable[[str], Client]
Layer = Callable[[Fwd, FlowerCallable], Bwd]
Layer = Callable[[FlowerContext, FlowerCallable], FlowerContext]

0 comments on commit 81db36f

Please sign in to comment.