Skip to content

Commit

Permalink
feat(framework) Enable PullMessages and PushMessages RPCs in `Grp…
Browse files Browse the repository at this point in the history
…cAdapterServicer`
  • Loading branch information
panh99 authored Jan 22, 2025
1 parent 1c03c29 commit 1f304f6
Showing 1 changed file with 20 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Fleet API gRPC adapter servicer."""


from logging import DEBUG, INFO
from logging import DEBUG
from typing import Callable, TypeVar

import grpc
Expand All @@ -31,35 +31,30 @@
from flwr.common.logger import log
from flwr.common.version import package_name, package_version
from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
PullTaskInsResponse,
PushTaskResRequest,
PushTaskResResponse,
PullMessagesRequest,
PushMessagesRequest,
)
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.linkstate import LinkStateFactory
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611

from ..grpc_rere.fleet_servicer import FleetServicer

T = TypeVar("T", bound=GrpcMessage)


def _handle(
msg_container: MessageContainer,
context: grpc.ServicerContext,
request_type: type[T],
handler: Callable[[T], GrpcMessage],
handler: Callable[[T, grpc.ServicerContext], GrpcMessage],
) -> MessageContainer:
req = request_type.FromString(msg_container.grpc_message_content)
res = handler(req)
res = handler(req, context)
res_cls = res.__class__
return MessageContainer(
metadata={
Expand All @@ -74,89 +69,26 @@ def _handle(
)


class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetServicer):
"""Fleet API via GrpcAdapter servicer."""

def __init__(
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
) -> None:
self.state_factory = state_factory
self.ffs_factory = ffs_factory

def SendReceive( # pylint: disable=too-many-return-statements
self, request: MessageContainer, context: grpc.ServicerContext
) -> MessageContainer:
"""."""
log(DEBUG, "GrpcAdapterServicer.SendReceive")
if request.grpc_message_name == CreateNodeRequest.__qualname__:
return _handle(request, CreateNodeRequest, self._create_node)
return _handle(request, context, CreateNodeRequest, self.CreateNode)
if request.grpc_message_name == DeleteNodeRequest.__qualname__:
return _handle(request, DeleteNodeRequest, self._delete_node)
return _handle(request, context, DeleteNodeRequest, self.DeleteNode)
if request.grpc_message_name == PingRequest.__qualname__:
return _handle(request, PingRequest, self._ping)
if request.grpc_message_name == PullTaskInsRequest.__qualname__:
return _handle(request, PullTaskInsRequest, self._pull_task_ins)
if request.grpc_message_name == PushTaskResRequest.__qualname__:
return _handle(request, PushTaskResRequest, self._push_task_res)
return _handle(request, context, PingRequest, self.Ping)
if request.grpc_message_name == GetRunRequest.__qualname__:
return _handle(request, GetRunRequest, self._get_run)
return _handle(request, context, GetRunRequest, self.GetRun)
if request.grpc_message_name == GetFabRequest.__qualname__:
return _handle(request, GetFabRequest, self._get_fab)
return _handle(request, context, GetFabRequest, self.GetFab)
if request.grpc_message_name == PullMessagesRequest.__qualname__:
return _handle(request, context, PullMessagesRequest, self.PullMessages)
if request.grpc_message_name == PushMessagesRequest.__qualname__:
return _handle(request, context, PushMessagesRequest, self.PushMessages)
raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")

def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
"""."""
log(INFO, "GrpcAdapter.CreateNode")
return message_handler.create_node(
request=request,
state=self.state_factory.state(),
)

def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse:
"""."""
log(INFO, "GrpcAdapter.DeleteNode")
return message_handler.delete_node(
request=request,
state=self.state_factory.state(),
)

def _ping(self, request: PingRequest) -> PingResponse:
"""."""
log(DEBUG, "GrpcAdapter.Ping")
return message_handler.ping(
request=request,
state=self.state_factory.state(),
)

def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse:
"""Pull TaskIns."""
log(INFO, "GrpcAdapter.PullTaskIns")
return message_handler.pull_task_ins(
request=request,
state=self.state_factory.state(),
)

def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse:
"""Push TaskRes."""
log(INFO, "GrpcAdapter.PushTaskRes")
return message_handler.push_task_res(
request=request,
state=self.state_factory.state(),
)

def _get_run(self, request: GetRunRequest) -> GetRunResponse:
"""Get run information."""
log(INFO, "GrpcAdapter.GetRun")
return message_handler.get_run(
request=request,
state=self.state_factory.state(),
)

def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
"""Get FAB."""
log(INFO, "GrpcAdapter.GetFab")
return message_handler.get_fab(
request=request,
ffs=self.ffs_factory.ffs(),
state=self.state_factory.state(),
)

0 comments on commit 1f304f6

Please sign in to comment.