Skip to content

Commit

Permalink
update handle
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 26, 2024
1 parent 5ac3faa commit d25ebb2
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 332 deletions.
47 changes: 28 additions & 19 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from logging import DEBUG
from pathlib import Path
from queue import Queue
from typing import Callable, Iterator, Optional, Tuple, Union
from typing import Callable, Iterator, Optional, Tuple, Union, cast

from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common import recordset_compat as compat
from flwr.common import serde, typing
from flwr.common import serde
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
Expand All @@ -33,10 +34,12 @@
)
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.recordset import RecordSet
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Reason,
ServerMessage,
)
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
Expand All @@ -54,7 +57,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:


@contextmanager
def grpc_connection(
def grpc_connection( # pylint: disable=R0915
server_address: str,
insecure: bool,
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
Expand Down Expand Up @@ -152,6 +155,12 @@ def receive() -> TaskIns:
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
)
task_type = TASK_TYPE_EVALUATE
elif field == "reconnect_ins":
recordset = RecordSet()
recordset.set_configs(
"config", ConfigsRecord({"seconds": proto.reconnect_ins.seconds})
)
task_type = "reconnect"
else:
raise ValueError(
"Unsupported instruction in ServerMessage, "
Expand Down Expand Up @@ -180,33 +189,33 @@ def send(task_res: TaskRes) -> None:
recordset = serde.recordset_from_proto(task_res.task.recordset)
task_type = task_res.task.task_type

# RecordSet --> *Res --> ClientMessage
# RecordSet --> *Res --> *Res proto -> ClientMessage proto
if task_type == TASK_TYPE_GET_PROPERTIES:
client_message = typing.ClientMessage(
get_properties_res=compat.recordset_to_getpropertiesres(recordset)
getpropres = compat.recordset_to_getpropertiesres(recordset)
msg_proto = ClientMessage(
get_properties_res=serde.get_properties_res_to_proto(getpropres)
)
elif task_type == TASK_TYPE_GET_PARAMETERS:
client_message = typing.ClientMessage(
get_parameters_res=compat.recordset_to_getparametersres(
recordset, False
)
getparamres = compat.recordset_to_getparametersres(recordset, False)
msg_proto = ClientMessage(
get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
)
elif task_type == TASK_TYPE_FIT:
client_message = typing.ClientMessage(
fit_res=compat.recordset_to_fitres(recordset, False)
)
fitres = compat.recordset_to_fitres(recordset, False)
msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
elif task_type == TASK_TYPE_EVALUATE:
client_message = typing.ClientMessage(
evaluate_res=compat.recordset_to_evaluateres(recordset)
evalres = compat.recordset_to_evaluateres(recordset)
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
elif task_type == "reconnect":
reason = cast(Reason.ValueType, recordset.get_configs("config")["reason"])
msg_proto = ClientMessage(
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
)
else:
raise ValueError(f"Invalid task type: {task_type}")

# ClientMessage --> ClientMessage proto
client_message_proto = serde.client_message_to_proto(client_message)

# Send ClientMessage proto
return queue.put(client_message_proto, block=False)
return queue.put(msg_proto, block=False)

try:
# Yield methods
Expand Down
45 changes: 22 additions & 23 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@

import grpc

from flwr.common import recordset_compat as compat
from flwr.common import serde
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import TASK_TYPE_GET_PROPERTIES
from flwr.common.recordset import RecordSet
from flwr.common.typing import Code, GetPropertiesRes, Status
from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Expand All @@ -35,11 +41,21 @@

EXPECTED_NUM_SERVER_MESSAGE = 10

SERVER_MESSAGE = ServerMessage()
SERVER_MESSAGE = ServerMessage(get_properties_ins=ServerMessage.GetPropertiesIns())
SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect_ins=ServerMessage.ReconnectIns())

CLIENT_MESSAGE = ClientMessage()
CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect_res=ClientMessage.DisconnectRes())
TASK_GET_PROPERTIES = Task(
task_type=TASK_TYPE_GET_PROPERTIES,
recordset=serde.recordset_to_proto(
compat.getpropertiesres_to_recordset(GetPropertiesRes(Status(Code.OK, ""), {}))
),
)
TASK_DISCONNECT = Task(
task_type="reconnect",
recordset=serde.recordset_to_proto(
RecordSet(configs={"config": ConfigsRecord({"reason": 0})})
),
)


def unused_tcp_port() -> int:
Expand Down Expand Up @@ -104,31 +120,14 @@ def run_client() -> int:
# Block until server responds with a message
task_ins = receive()

if task_ins is None:
raise ValueError("Unexpected None value")

# pylint: disable=no-member
if task_ins.HasField("task") and task_ins.task.HasField(
"legacy_server_message"
):
server_message = task_ins.task.legacy_server_message
else:
server_message = None
# pylint: enable=no-member

if server_message is None:
raise ValueError("Unexpected None value")

messages_received += 1
if server_message.HasField("reconnect_ins"):
task_res = TaskRes(
task=Task(legacy_client_message=CLIENT_MESSAGE_DISCONNECT)
)
if task_ins.task.task_type == "reconnect": # type: ignore
task_res = TaskRes(task=TASK_DISCONNECT)
send(task_res)
break

# Process server_message and send client_message...
task_res = TaskRes(task=Task(legacy_client_message=CLIENT_MESSAGE))
task_res = TaskRes(task=TASK_GET_PROPERTIES)
send(task_res)

return messages_received
Expand Down
Loading

0 comments on commit d25ebb2

Please sign in to comment.