Skip to content

Commit

Permalink
Add GetRun rpc to FleetAPI. (#3253)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Apr 14, 2024
1 parent 1b4b3e7 commit f86f849
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 5 deletions.
11 changes: 11 additions & 0 deletions src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ service Fleet {
//
// HTTP API path: /api/v1/fleet/push-task-res
rpc PushTaskRes(PushTaskResRequest) returns (PushTaskResResponse) {}

rpc GetRun(GetRunRequest) returns (GetRunResponse) {}
}

// CreateNode messages
Expand Down Expand Up @@ -68,4 +70,13 @@ message PushTaskResResponse {
map<string, uint32> results = 2;
}

// GetRun messages
message Run {
sint64 run_id = 1;
string fab_id = 2;
string fab_version = 3;
}
message GetRunRequest { sint64 run_id = 1; }
message GetRunResponse { Run run = 1; }

message Reconnect { uint64 reconnect = 1; }
16 changes: 11 additions & 5 deletions src/py/flwr/proto/fleet_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions src/py/flwr/proto/fleet_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,48 @@ class PushTaskResResponse(google.protobuf.message.Message):
def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","results",b"results"]) -> None: ...
global___PushTaskResResponse = PushTaskResResponse

class Run(google.protobuf.message.Message):
"""GetRun messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
FAB_ID_FIELD_NUMBER: builtins.int
FAB_VERSION_FIELD_NUMBER: builtins.int
run_id: builtins.int
fab_id: typing.Text
fab_version: typing.Text
def __init__(self,
*,
run_id: builtins.int = ...,
fab_id: typing.Text = ...,
fab_version: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ...
global___Run = Run

class GetRunRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
run_id: builtins.int
def __init__(self,
*,
run_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
global___GetRunRequest = GetRunRequest

class GetRunResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_FIELD_NUMBER: builtins.int
@property
def run(self) -> global___Run: ...
def __init__(self,
*,
run: typing.Optional[global___Run] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ...
global___GetRunResponse = GetRunResponse

class Reconnect(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RECONNECT_FIELD_NUMBER: builtins.int
Expand Down
33 changes: 33 additions & 0 deletions src/py/flwr/proto/fleet_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString,
)
self.GetRun = channel.unary_unary(
'/flwr.proto.Fleet/GetRun',
request_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString,
)


class FleetServicer(object):
Expand Down Expand Up @@ -80,6 +85,12 @@ def PushTaskRes(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetRun(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_FleetServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -108,6 +119,11 @@ def add_FleetServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString,
response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString,
),
'GetRun': grpc.unary_unary_rpc_method_handler(
servicer.GetRun,
request_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.FromString,
response_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Fleet', rpc_method_handlers)
Expand Down Expand Up @@ -202,3 +218,20 @@ def PushTaskRes(request,
flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetRun(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/GetRun',
flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString,
flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
10 changes: 10 additions & 0 deletions src/py/flwr/proto/fleet_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class FleetStub:
HTTP API path: /api/v1/fleet/push-task-res
"""

GetRun: grpc.UnaryUnaryMultiCallable[
flwr.proto.fleet_pb2.GetRunRequest,
flwr.proto.fleet_pb2.GetRunResponse]


class FleetServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand Down Expand Up @@ -78,5 +82,11 @@ class FleetServicer(metaclass=abc.ABCMeta):
"""
pass

@abc.abstractmethod
def GetRun(self,
request: flwr.proto.fleet_pb2.GetRunRequest,
context: grpc.ServicerContext,
) -> flwr.proto.fleet_pb2.GetRunResponse: ...


def add_FleetServicer_to_server(servicer: FleetServicer, server: grpc.Server) -> None: ...
12 changes: 12 additions & 0 deletions src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
GetRunRequest,
GetRunResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
Expand Down Expand Up @@ -90,3 +92,13 @@ def PushTaskRes(
request=request,
state=self.state_factory.state(),
)

def GetRun(
self, request: GetRunRequest, context: grpc.ServicerContext
) -> GetRunResponse:
"""Get run information."""
log(INFO, "FleetServicer.GetRun")
return message_handler.get_run(
request=request,
state=self.state_factory.state(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
GetRunRequest,
GetRunResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
Expand Down Expand Up @@ -101,3 +103,10 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
results={str(task_id): 0},
)
return response


def get_run(
request: GetRunRequest, state: State # pylint: disable=W0613
) -> GetRunResponse:
"""Get run information."""
return GetRunResponse()
30 changes: 30 additions & 0 deletions src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
GetRunRequest,
PingRequest,
PullTaskInsRequest,
PushTaskResRequest,
Expand Down Expand Up @@ -179,12 +180,41 @@ async def ping(request: Request) -> Response:
)


async def get_run(request: Request) -> Response:
"""GetRun."""
_check_headers(request.headers)

# Get the request body as raw bytes
get_run_request_bytes: bytes = await request.body()

# Deserialize ProtoBuf
get_run_request_proto = GetRunRequest()
get_run_request_proto.ParseFromString(get_run_request_bytes)

# Get state from app
state: State = app.state.STATE_FACTORY.state()

# Handle message
get_run_response_proto = message_handler.get_run(
request=get_run_request_proto, state=state
)

# Return serialized ProtoBuf
get_run_response_bytes = get_run_response_proto.SerializeToString()
return Response(
status_code=200,
content=get_run_response_bytes,
headers={"Content-Type": "application/protobuf"},
)


routes = [
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
]

app: Starlette = Starlette(
Expand Down

0 comments on commit f86f849

Please sign in to comment.