Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Fwd/Bwd with Message/Context #2842

Merged
merged 21 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flwr.client.client import Client
from flwr.client.flower import Flower
from flwr.client.typing import Bwd, ClientFn, Fwd
from flwr.client.typing import ClientFn
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.constant import (
Expand All @@ -35,6 +35,7 @@
TRANSPORT_TYPES,
)
from flwr.common.logger import log, warn_experimental_feature
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .flower import load_flower_callable
Expand Down Expand Up @@ -351,27 +352,32 @@ def _load_app() -> Flower:
send(task_res)
break

# Register state
# Register context for this run
node_state.register_context(run_id=task_ins.run_id)

# Retrieve context for this run
context = node_state.retrieve_context(run_id=task_ins.run_id)

# Get Message from TaskIns
message = message_from_taskins(task_ins)

# Load app
app: Flower = load_flower_callable_fn()

# Handle task message
fwd_msg: Fwd = Fwd(
task_ins=task_ins,
context=node_state.retrieve_context(run_id=task_ins.run_id),
)
bwd_msg: Bwd = app(fwd=fwd_msg)
out_message = app(message=message, context=context)

# Update node state
node_state.update_context(
run_id=fwd_msg.task_ins.run_id,
context=bwd_msg.context,
run_id=message.metadata.run_id,
context=context,
)

# Construct TaskRes from out_message
task_res = message_to_taskres(out_message)

# Send
send(bwd_msg.task_res)
send(task_res)

# Unregister node
if delete_node is not None:
Expand Down
25 changes: 15 additions & 10 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
import importlib
from typing import List, Optional, cast

from flwr.client.message_handler.message_handler import handle
from flwr.client.message_handler.message_handler import (
handle_legacy_message_from_tasktype,
)
from flwr.client.middleware.utils import make_ffn
from flwr.client.typing import Bwd, ClientFn, Fwd, Layer
from flwr.client.typing import ClientFn, Layer
from flwr.common.context import Context
from flwr.common.message import Message


class Flower:
Expand Down Expand Up @@ -55,20 +59,21 @@ def __init__(
layers: Optional[List[Layer]] = None,
) -> None:
# Create wrapper function for `handle`
def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name
task_res, context_updated = handle(
client_fn=client_fn,
context=fwd.context,
task_ins=fwd.task_ins,
def ffn(
message: Message,
context: Context,
) -> Message: # pylint: disable=invalid-name
out_message = handle_legacy_message_from_tasktype(
client_fn=client_fn, message=message, context=context
)
return Bwd(task_res=task_res, context=context_updated)
return out_message

# Wrap middleware layers around the wrapped handle function
self._call = make_ffn(ffn, layers if layers is not None else [])

def __call__(self, fwd: Fwd) -> Bwd:
def __call__(self, message: Message, context: Context) -> Message:
"""."""
return self._call(fwd)
return self._call(message, context)


class LoadCallableError(Exception):
Expand Down
94 changes: 88 additions & 6 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,26 @@
from logging import DEBUG
from pathlib import Path
from queue import Queue
from typing import Callable, Iterator, Optional, Tuple, Union
from typing import Callable, Iterator, Optional, Tuple, Union, cast

from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common import recordset_compat as compat
from flwr.common import serde
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
TASK_TYPE_GET_PARAMETERS,
TASK_TYPE_GET_PROPERTIES,
)
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.recordset import RecordSet
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Reason,
ServerMessage,
)
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
Expand All @@ -46,7 +57,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:


@contextmanager
def grpc_connection(
def grpc_connection( # pylint: disable=R0915
server_address: str,
insecure: bool,
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
Expand Down Expand Up @@ -118,7 +129,48 @@ def grpc_connection(
server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None))

def receive() -> TaskIns:
server_message = next(server_message_iterator)
# Receive ServerMessage proto
proto = next(server_message_iterator)

# ServerMessage proto --> *Ins --> RecordSet
field = proto.WhichOneof("msg")
task_type = ""
if field == "get_properties_ins":
recordset = compat.getpropertiesins_to_recordset(
serde.get_properties_ins_from_proto(proto.get_properties_ins)
)
task_type = TASK_TYPE_GET_PROPERTIES
elif field == "get_parameters_ins":
recordset = compat.getparametersins_to_recordset(
serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
)
task_type = TASK_TYPE_GET_PARAMETERS
elif field == "fit_ins":
recordset = compat.fitins_to_recordset(
serde.fit_ins_from_proto(proto.fit_ins), False
)
task_type = TASK_TYPE_FIT
elif field == "evaluate_ins":
recordset = compat.evaluateins_to_recordset(
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
)
task_type = TASK_TYPE_EVALUATE
elif field == "reconnect_ins":
recordset = RecordSet()
recordset.set_configs(
"config", ConfigsRecord({"seconds": proto.reconnect_ins.seconds})
)
task_type = "reconnect"
else:
raise ValueError(
"Unsupported instruction in ServerMessage, "
"cannot deserialize from ProtoBuf"
)

# RecordSet --> RecordSet proto
recordset_proto = serde.recordset_to_proto(recordset)

# Construct TaskIns
return TaskIns(
task_id=str(uuid.uuid4()),
group_id="",
Expand All @@ -127,13 +179,43 @@ def receive() -> TaskIns:
producer=Node(node_id=0, anonymous=True),
consumer=Node(node_id=0, anonymous=True),
ancestry=[],
legacy_server_message=server_message,
task_type=task_type,
recordset=recordset_proto,
),
)

def send(task_res: TaskRes) -> None:
msg = task_res.task.legacy_client_message
return queue.put(msg, block=False)
# Retrieve RecordSet and task_type
recordset = serde.recordset_from_proto(task_res.task.recordset)
task_type = task_res.task.task_type

# RecordSet --> *Res --> *Res proto -> ClientMessage proto
if task_type == TASK_TYPE_GET_PROPERTIES:
getpropres = compat.recordset_to_getpropertiesres(recordset)
msg_proto = ClientMessage(
get_properties_res=serde.get_properties_res_to_proto(getpropres)
)
elif task_type == TASK_TYPE_GET_PARAMETERS:
getparamres = compat.recordset_to_getparametersres(recordset, False)
msg_proto = ClientMessage(
get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
)
elif task_type == TASK_TYPE_FIT:
fitres = compat.recordset_to_fitres(recordset, False)
msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
elif task_type == TASK_TYPE_EVALUATE:
evalres = compat.recordset_to_evaluateres(recordset)
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
elif task_type == "reconnect":
reason = cast(Reason.ValueType, recordset.get_configs("config")["reason"])
msg_proto = ClientMessage(
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
)
else:
raise ValueError(f"Invalid task type: {task_type}")

# Send ClientMessage proto
return queue.put(msg_proto, block=False)

try:
# Yield methods
Expand Down
45 changes: 22 additions & 23 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@

import grpc

from flwr.common import recordset_compat as compat
from flwr.common import serde
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import TASK_TYPE_GET_PROPERTIES
from flwr.common.recordset import RecordSet
from flwr.common.typing import Code, GetPropertiesRes, Status
from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Expand All @@ -35,11 +41,21 @@

EXPECTED_NUM_SERVER_MESSAGE = 10

SERVER_MESSAGE = ServerMessage()
SERVER_MESSAGE = ServerMessage(get_properties_ins=ServerMessage.GetPropertiesIns())
SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect_ins=ServerMessage.ReconnectIns())

CLIENT_MESSAGE = ClientMessage()
CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect_res=ClientMessage.DisconnectRes())
TASK_GET_PROPERTIES = Task(
task_type=TASK_TYPE_GET_PROPERTIES,
recordset=serde.recordset_to_proto(
compat.getpropertiesres_to_recordset(GetPropertiesRes(Status(Code.OK, ""), {}))
),
)
TASK_DISCONNECT = Task(
task_type="reconnect",
recordset=serde.recordset_to_proto(
RecordSet(configs={"config": ConfigsRecord({"reason": 0})})
),
)


def unused_tcp_port() -> int:
Expand Down Expand Up @@ -104,31 +120,14 @@ def run_client() -> int:
# Block until server responds with a message
task_ins = receive()

if task_ins is None:
raise ValueError("Unexpected None value")

# pylint: disable=no-member
if task_ins.HasField("task") and task_ins.task.HasField(
"legacy_server_message"
):
server_message = task_ins.task.legacy_server_message
else:
server_message = None
# pylint: enable=no-member

if server_message is None:
raise ValueError("Unexpected None value")

messages_received += 1
if server_message.HasField("reconnect_ins"):
task_res = TaskRes(
task=Task(legacy_client_message=CLIENT_MESSAGE_DISCONNECT)
)
if task_ins.task.task_type == "reconnect": # type: ignore
task_res = TaskRes(task=TASK_DISCONNECT)
send(task_res)
break

# Process server_message and send client_message...
task_res = TaskRes(task=Task(legacy_client_message=CLIENT_MESSAGE))
task_res = TaskRes(task=TASK_GET_PROPERTIES)
send(task_res)

return messages_received
Expand Down
4 changes: 1 addition & 3 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def receive() -> Optional[TaskIns]:
task_ins: Optional[TaskIns] = get_task_ins(response)

# Discard the current TaskIns if not valid
if task_ins is not None and not validate_task_ins(
task_ins, discard_reconnect_ins=True
):
if task_ins is not None and not validate_task_ins(task_ins):
task_ins = None

# Remember `task_ins` until `task_res` is available
Expand Down
Loading
Loading