From 0d8b0d2eb912e23443c79db8a4c83797551a8c7f Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 12 Apr 2024 16:47:54 +0100 Subject: [PATCH] add endpoint --- .../fleet/grpc_rere/fleet_servicer.py | 12 ++++++++ .../fleet/message_handler/message_handler.py | 9 ++++++ .../superlink/fleet/rest_rere/rest_api.py | 30 +++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index eb8dd800ea37..03a2ec064213 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -26,6 +26,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + GetRunRequest, + GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -90,3 +92,13 @@ def PushTaskRes( request=request, state=self.state_factory.state(), ) + + def GetRun( + self, request: GetRunRequest, context: grpc.ServicerContext + ) -> GetRunResponse: + """Get run information.""" + log(INFO, "FleetServicer.GetRun") + return message_handler.get_run( + request=request, + state=self.state_factory.state(), + ) diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 39edd606b464..5480d15e5936 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -24,6 +24,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + GetRunRequest, + GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -101,3 +103,10 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo results={str(task_id): 0}, ) return response + + +def get_run( + request: GetRunRequest, state: State # pylint: disable=W0613 +) -> GetRunResponse: + """Get run information.""" + return GetRunResponse() diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index 33d17ef1d579..8ac7c6cfc613 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -21,6 +21,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, + GetRunRequest, PingRequest, PullTaskInsRequest, PushTaskResRequest, @@ -179,12 +180,41 @@ async def ping(request: Request) -> Response: ) +async def get_run(request: Request) -> Response: + """GetRun.""" + _check_headers(request.headers) + + # Get the request body as raw bytes + get_run_request_bytes: bytes = await request.body() + + # Deserialize ProtoBuf + get_run_request_proto = GetRunRequest() + get_run_request_proto.ParseFromString(get_run_request_bytes) + + # Get state from app + state: State = app.state.STATE_FACTORY.state() + + # Handle message + get_run_response_proto = message_handler.get_run( + request=get_run_request_proto, state=state + ) + + # Return serialized ProtoBuf + get_run_response_bytes = get_run_response_proto.SerializeToString() + return Response( + status_code=200, + content=get_run_response_bytes, + headers={"Content-Type": "application/protobuf"}, + ) + + routes = [ Route("/api/v0/fleet/create-node", create_node, methods=["POST"]), Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]), Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]), Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]), Route("/api/v0/fleet/ping", ping, methods=["POST"]), + Route("/api/v0/fleet/get-run", get_run, methods=["POST"]), ] app: Starlette = Starlette(