Skip to content

Commit

Permalink
adopt pull/push messages rpcs
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 15, 2025
1 parent c28fdc5 commit eb36760
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 92 deletions.
24 changes: 12 additions & 12 deletions src/py/flwr/client/connection/fleet_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
DeleteNodeResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
PullTaskInsResponse,
PushTaskResRequest,
PushTaskResResponse,
PullMessagesRequest,
PullMessagesResponse,
PushMessagesRequest,
PushMessagesResponse,
)
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611

Expand All @@ -56,16 +56,16 @@ def DeleteNode( # pylint: disable=C0103
"""Fleet.DeleteNode."""

@abstractmethod
def PullTaskIns( # pylint: disable=C0103
self, request: PullTaskInsRequest, **kwargs: Any
) -> PullTaskInsResponse:
"""Fleet.PullTaskIns."""
def PullMessages( # pylint: disable=C0103
self, request: PullMessagesRequest, **kwargs: Any
) -> PullMessagesResponse:
"""Fleet.PullMessages."""

@abstractmethod
def PushTaskRes( # pylint: disable=C0103
self, request: PushTaskResRequest, **kwargs: Any
) -> PushTaskResResponse:
"""Fleet.PushTaskRes."""
def PushMessages( # pylint: disable=C0103
self, request: PushMessagesRequest, **kwargs: Any
) -> PushMessagesResponse:
"""Fleet.PushMessages."""

@abstractmethod
def GetRun( # pylint: disable=C0103
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,8 @@
PingResponse,
PullMessagesRequest,
PullMessagesResponse,
PullTaskInsRequest,
PullTaskInsResponse,
PushMessagesRequest,
PushMessagesResponse,
PushTaskResRequest,
PushTaskResResponse,
)
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
Expand Down Expand Up @@ -165,24 +161,12 @@ def Ping( # pylint: disable=C0103
"""."""
return self._send_and_receive(request, PingResponse, **kwargs)

def PullTaskIns( # pylint: disable=C0103
self, request: PullTaskInsRequest, **kwargs: Any
) -> PullTaskInsResponse:
"""."""
return self._send_and_receive(request, PullTaskInsResponse, **kwargs)

def PullMessages( # pylint: disable=C0103
self, request: PullMessagesRequest, **kwargs: Any
) -> PullMessagesResponse:
"""."""
return self._send_and_receive(request, PullMessagesResponse, **kwargs)

def PushTaskRes( # pylint: disable=C0103
self, request: PushTaskResRequest, **kwargs: Any
) -> PushTaskResResponse:
"""."""
return self._send_and_receive(request, PushTaskResResponse, **kwargs)

def PushMessages( # pylint: disable=C0103
self, request: PushMessagesRequest, **kwargs: Any
) -> PushMessagesResponse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,9 @@ def test_grpc_adapter_methods() -> None:
if inspect.isfunction(ref)
}

# Backward compatibility
expected_methods.remove("PullTaskIns")
expected_methods.remove("PushTaskRes")

# Assert
assert expected_methods.issubset(methods)
62 changes: 33 additions & 29 deletions src/py/flwr/client/connection/grpc_rere/client_interceptor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import grpc

from flwr.common import GRPC_MAX_MESSAGE_LENGTH, serde
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.message import Message, Metadata
from flwr.common.record import RecordSet
from flwr.common.retry_invoker import RetryInvoker, exponential
Expand All @@ -35,21 +35,26 @@
generate_shared_key,
public_key_to_bytes,
)
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611

# pylint: disable=E0611
from flwr.proto.fleet_pb2 import (
CreateNodeRequest,
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
PullTaskInsRequest,
PullTaskInsResponse,
PushTaskResRequest,
PushTaskResResponse,
PullMessagesRequest,
PullMessagesResponse,
PushMessagesRequest,
PushMessagesResponse,
)
from flwr.proto.fleet_pb2_grpc import FleetServicer
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611
from flwr.proto.message_pb2 import Message as MessageProto
from flwr.proto.message_pb2 import Metadata as MetadataProto
from flwr.proto.node_pb2 import Node
from flwr.proto.recordset_pb2 import RecordSet as RecordSetProto
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse

# pylint: enable=E0611
from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request
from .grpc_rere_fleet_connection import GrpcRereFleetConnection

Expand All @@ -66,10 +71,11 @@ def __init__(self) -> None:
self.server_private_key, self.server_public_key = generate_key_pairs()
self._received_message_bytes: bytes = b""

def unary_unary(
self, request: Request, context: grpc.ServicerContext
) -> Union[
CreateNodeResponse, DeleteNodeResponse, PushTaskResResponse, PullTaskInsResponse
def unary_unary(self, request: Request, context: grpc.ServicerContext) -> Union[
CreateNodeResponse,
DeleteNodeResponse,
PushMessagesResponse,
PullMessagesResponse,
]:
"""Handle unary call."""
with self._lock:
Expand All @@ -90,16 +96,14 @@ def unary_unary(
return CreateNodeResponse(node=Node(node_id=123))
if isinstance(request, DeleteNodeRequest):
return DeleteNodeResponse()
if isinstance(request, PushTaskResRequest):
return PushTaskResResponse()

return PullTaskInsResponse(
task_ins_list=[
TaskIns(
task=Task(
consumer=Node(node_id=123),
recordset=serde.recordset_to_proto(RecordSet()),
)
if isinstance(request, PushMessagesRequest):
return PushMessagesResponse()

return PullMessagesResponse(
messages_list=[
MessageProto(
metadata=MetadataProto(dst_node_id=123),
content=RecordSetProto(),
)
]
)
Expand Down Expand Up @@ -129,15 +133,15 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None:
request_deserializer=DeleteNodeRequest.FromString,
response_serializer=DeleteNodeResponse.SerializeToString,
),
"PullTaskIns": grpc.unary_unary_rpc_method_handler(
"PullMessages": grpc.unary_unary_rpc_method_handler(
servicer.unary_unary,
request_deserializer=PullTaskInsRequest.FromString,
response_serializer=PullTaskInsResponse.SerializeToString,
request_deserializer=PullMessagesRequest.FromString,
response_serializer=PullMessagesResponse.SerializeToString,
),
"PushTaskRes": grpc.unary_unary_rpc_method_handler(
"PushMessages": grpc.unary_unary_rpc_method_handler(
servicer.unary_unary,
request_deserializer=PushTaskResRequest.FromString,
response_serializer=PushTaskResResponse.SerializeToString,
request_deserializer=PushMessagesRequest.FromString,
response_serializer=PushMessagesResponse.SerializeToString,
),
"GetRun": grpc.unary_unary_rpc_method_handler(
servicer.unary_unary,
Expand Down
40 changes: 19 additions & 21 deletions src/py/flwr/client/connection/rere_fleet_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from flwr.client.heartbeat import start_ping_loop
from flwr.client.message_handler.message_handler import validate_out_message
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.constant import (
PING_BASE_MULTIPLIER,
Expand All @@ -39,7 +38,7 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
from flwr.common.typing import Fab, Run, RunNotRunningException
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
Expand All @@ -48,12 +47,12 @@
DeleteNodeRequest,
PingRequest,
PingResponse,
PullTaskInsRequest,
PushTaskResRequest,
PullMessagesRequest,
PullMessagesResponse,
PushMessagesRequest,
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

from .fleet_api import FleetApi
from .fleet_connection import FleetConnection
Expand Down Expand Up @@ -170,22 +169,23 @@ def receive(self) -> Message | None:
log(ERROR, "Node instance missing")
return None

# Request instructions (task) from server
req = PullTaskInsRequest(node=self.node)
res = self.retry_invoker.invoke(self.api.PullTaskIns, request=req)
# Request instructions (message) from server
req = PullMessagesRequest(node=self.node)
res: PullMessagesResponse = self.retry_invoker.invoke(
self.api.PullMessages, request=req
)

# Get the current TaskIns
task_ins: TaskIns | None = get_task_ins(res)
message_proto = res.messages_list[0] if res.messages_list else None

# Discard the current TaskIns if not valid
if task_ins is not None and not (
task_ins.task.consumer.node_id == self.node.node_id
and validate_task_ins(task_ins)
# Discard the current message if not valid
if message_proto is not None and not (
message_proto.metadata.dst_node_id == self.node.node_id
):
task_ins = None
message_proto = None

# Construct the Message
in_message = message_from_taskins(task_ins) if task_ins else None
in_message = message_from_proto(message_proto) if message_proto else None

# Remember `metadata` of the in message
if in_message:
Expand Down Expand Up @@ -213,12 +213,10 @@ def send(self, message: Message) -> None:
log(ERROR, "Invalid out message")
return

# Construct TaskRes
task_res = message_to_taskres(message)

# Serialize ProtoBuf to bytes
req = PushTaskResRequest(node=self.node, task_res_list=[task_res])
self.retry_invoker.invoke(self.api.PushTaskRes, req)
# Serialize Message
message_proto = message_to_proto(message=message)
req = PushMessagesRequest(node=self.node, messages_list=[message_proto])
self.retry_invoker.invoke(self.api.PushMessages, req)

# Cleanup
metadata = None
Expand Down
32 changes: 18 additions & 14 deletions src/py/flwr/client/connection/rest/rest_fleet_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
DeleteNodeResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
PullTaskInsResponse,
PushTaskResRequest,
PushTaskResResponse,
PullMessagesRequest,
PullMessagesResponse,
PushMessagesRequest,
PushMessagesResponse,
)
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611

Expand All @@ -51,8 +51,8 @@

PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
PATH_PING: str = "api/v0/fleet/ping"
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
Expand Down Expand Up @@ -181,17 +181,21 @@ def Ping( # pylint: disable=C0103
"""."""
return self._request(request, PingResponse, PATH_PING, **kwargs)

def PullTaskIns( # pylint: disable=C0103
self, request: PullTaskInsRequest, **kwargs: Any
) -> PullTaskInsResponse:
def PullMessages( # pylint: disable=C0103
self, request: PullMessagesRequest, **kwargs: Any
) -> PullMessagesResponse:
"""."""
return self._request(request, PullTaskInsResponse, PATH_PULL_TASK_INS, **kwargs)
return self._request(
request, PullMessagesResponse, PATH_PULL_MESSAGES, **kwargs
)

def PushTaskRes( # pylint: disable=C0103
self, request: PushTaskResRequest, **kwargs: Any
) -> PushTaskResResponse:
def PushMessages( # pylint: disable=C0103
self, request: PushMessagesRequest, **kwargs: Any
) -> PushMessagesResponse:
"""."""
return self._request(request, PushTaskResResponse, PATH_PUSH_TASK_RES, **kwargs)
return self._request(
request, PushMessagesResponse, PATH_PUSH_MESSAGES, **kwargs
)

def GetRun( # pylint: disable=C0103
self, request: GetRunRequest, **kwargs: Any
Expand Down

0 comments on commit eb36760

Please sign in to comment.