Skip to content

Commit

Permalink
update send and receive for grpc-bidi client
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 26, 2024
1 parent 6aa796d commit 69a7a17
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 12 deletions.
81 changes: 77 additions & 4 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
from typing import Callable, Iterator, Optional, Tuple, Union

from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common import recordset_compat as compat
from flwr.common import serde, typing
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
TASK_TYPE_GET_PARAMETERS,
TASK_TYPE_GET_PROPERTIES,
)
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
Expand Down Expand Up @@ -118,7 +126,42 @@ def grpc_connection(
server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None))

def receive() -> TaskIns:
server_message = next(server_message_iterator)
# Receive ServerMessage proto
proto = next(server_message_iterator)

# ServerMessage proto --> *Ins --> RecordSet
field = proto.WhichOneof("msg")
task_type = ""
if field == "get_properties_ins":
recordset = compat.getpropertiesins_to_recordset(
serde.get_properties_ins_from_proto(proto.get_properties_ins)
)
task_type = TASK_TYPE_GET_PROPERTIES
elif field == "get_parameters_ins":
recordset = compat.getparametersins_to_recordset(
serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
)
task_type = TASK_TYPE_GET_PARAMETERS
elif field == "fit_ins":
recordset = compat.fitins_to_recordset(
serde.fit_ins_from_proto(proto.fit_ins), False
)
task_type = TASK_TYPE_FIT
elif field == "evaluate_ins":
recordset = compat.evaluateins_to_recordset(
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
)
task_type = TASK_TYPE_EVALUATE
else:
raise ValueError(
"Unsupported instruction in ServerMessage, "
"cannot deserialize from ProtoBuf"
)

# RecordSet --> RecordSet proto
recordset_proto = serde.recordset_to_proto(recordset)

# Construct TaskIns
return TaskIns(
task_id=str(uuid.uuid4()),
group_id="",
Expand All @@ -127,13 +170,43 @@ def receive() -> TaskIns:
producer=Node(node_id=0, anonymous=True),
consumer=Node(node_id=0, anonymous=True),
ancestry=[],
legacy_server_message=server_message,
task_type=task_type,
recordset=recordset_proto,
),
)

def send(task_res: TaskRes) -> None:
msg = task_res.task.legacy_client_message
return queue.put(msg, block=False)
# Retrieve RecordSet and task_type
recordset = serde.recordset_from_proto(task_res.task.recordset)
task_type = task_res.task.task_type

# RecordSet --> *Res --> ClientMessage
if task_type == TASK_TYPE_GET_PROPERTIES:
client_message = typing.ClientMessage(
get_properties_res=compat.recordset_to_getpropertiesres(recordset)
)
elif task_type == TASK_TYPE_GET_PARAMETERS:
client_message = typing.ClientMessage(
get_parameters_res=compat.recordset_to_getparametersres(
recordset, False
)
)
elif task_type == TASK_TYPE_FIT:
client_message = typing.ClientMessage(
fit_res=compat.recordset_to_fitres(recordset, False)
)
elif task_type == TASK_TYPE_EVALUATE:
client_message = typing.ClientMessage(
evaluate_res=compat.recordset_to_evaluateres(recordset)
)
else:
raise ValueError(f"Invalid task type: {task_type}")

# ClientMessage --> ClientMessage proto
client_message_proto = serde.client_message_to_proto(client_message)

# Send ClientMessage proto
return queue.put(client_message_proto, block=False)

try:
# Yield methods
Expand Down
26 changes: 18 additions & 8 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
TASK_TYPE_GET_PARAMETERS,
TASK_TYPE_GET_PROPERTIES,
)
from flwr.common.context import Context
from flwr.common.message import Message
from flwr.common.recordset import RecordSet
Expand Down Expand Up @@ -201,34 +207,38 @@ def handle_legacy_message_from_tasktype(
task_type = message.metadata.task_type

out_message = Message(metadata=message.metadata, message=RecordSet())
if task_type == "get_properties_ins":
# Handle GetPropertiesIns
if task_type == TASK_TYPE_GET_PROPERTIES:
get_properties_res = maybe_call_get_properties(
client=client,
get_properties_ins=recordset_to_getpropertiesins(message.message),
)
out_message.message = getpropertiesres_to_recordset(get_properties_res)
elif task_type == "get_parameteres_ins":
# Handle GetParametersIns
elif task_type == TASK_TYPE_GET_PARAMETERS:
get_parameters_res = maybe_call_get_parameters(
client=client,
get_parameters_ins=recordset_to_getparametersins(message.message),
)
out_message.message = getparametersres_to_recordset(get_parameters_res)
elif task_type == "fit_ins":
out_message.message = getparametersres_to_recordset(
get_parameters_res, keep_input=False
)
# Handle FitIns
elif task_type == TASK_TYPE_FIT:
fit_res = maybe_call_fit(
client=client,
fit_ins=recordset_to_fitins(message.message, keep_input=False),
)
out_message.message = fitres_to_recordset(fit_res, keep_input=False)
elif task_type == "evaluate_ins":
# Handle EvaluateIns
elif task_type == TASK_TYPE_EVALUATE:
evaluate_res = maybe_call_evaluate(
client=client,
evaluate_ins=recordset_to_evaluateins(message.message, keep_input=False),
)
out_message.message = evaluateres_to_recordset(evaluate_res)
else:
# TODO: what to do with reconnect?
print("do something")

raise ValueError(f"Invalid task type: {task_type}")
return out_message


Expand Down
5 changes: 5 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
]

TASK_TYPE_GET_PROPERTIES = "get_properties"
TASK_TYPE_GET_PARAMETERS = "get_parameters"
TASK_TYPE_FIT = "fit"
TASK_TYPE_EVALUATE = "evaluate"

0 comments on commit 69a7a17

Please sign in to comment.