Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GetRun rpc to FleetAPI. #3253

Merged
merged 6 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ service Fleet {
//
// HTTP API path: /api/v1/fleet/push-task-res
rpc PushTaskRes(PushTaskResRequest) returns (PushTaskResResponse) {}

rpc GetRun(GetRunRequest) returns (GetRunResponse) {}
}

// CreateNode messages
Expand Down Expand Up @@ -68,4 +70,11 @@ message PushTaskResResponse {
map<string, uint32> results = 2;
}

// GetRun messages
message GetRunRequest { sint64 run_id = 1; }
message GetRunResponse {
string fab_id = 1;
string fab_version = 2;
}

message Reconnect { uint64 reconnect = 1; }
14 changes: 9 additions & 5 deletions src/py/flwr/proto/fleet_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions src/py/flwr/proto/fleet_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,32 @@ class PushTaskResResponse(google.protobuf.message.Message):
def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","results",b"results"]) -> None: ...
global___PushTaskResResponse = PushTaskResResponse

class GetRunRequest(google.protobuf.message.Message):
"""GetRun messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
run_id: builtins.int
def __init__(self,
*,
run_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
global___GetRunRequest = GetRunRequest

class GetRunResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
FAB_ID_FIELD_NUMBER: builtins.int
FAB_VERSION_FIELD_NUMBER: builtins.int
fab_id: typing.Text
fab_version: typing.Text
def __init__(self,
*,
fab_id: typing.Text = ...,
fab_version: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version"]) -> None: ...
global___GetRunResponse = GetRunResponse

class Reconnect(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RECONNECT_FIELD_NUMBER: builtins.int
Expand Down
33 changes: 33 additions & 0 deletions src/py/flwr/proto/fleet_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString,
)
self.GetRun = channel.unary_unary(
'/flwr.proto.Fleet/GetRun',
request_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString,
)


class FleetServicer(object):
Expand Down Expand Up @@ -80,6 +85,12 @@ def PushTaskRes(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetRun(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_FleetServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -108,6 +119,11 @@ def add_FleetServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString,
response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString,
),
'GetRun': grpc.unary_unary_rpc_method_handler(
servicer.GetRun,
request_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.FromString,
response_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Fleet', rpc_method_handlers)
Expand Down Expand Up @@ -202,3 +218,20 @@ def PushTaskRes(request,
flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetRun(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/GetRun',
flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString,
flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
10 changes: 10 additions & 0 deletions src/py/flwr/proto/fleet_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class FleetStub:
HTTP API path: /api/v1/fleet/push-task-res
"""

GetRun: grpc.UnaryUnaryMultiCallable[
flwr.proto.fleet_pb2.GetRunRequest,
flwr.proto.fleet_pb2.GetRunResponse]


class FleetServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand Down Expand Up @@ -78,5 +82,11 @@ class FleetServicer(metaclass=abc.ABCMeta):
"""
pass

@abc.abstractmethod
def GetRun(self,
request: flwr.proto.fleet_pb2.GetRunRequest,
context: grpc.ServicerContext,
) -> flwr.proto.fleet_pb2.GetRunResponse: ...


def add_FleetServicer_to_server(servicer: FleetServicer, server: grpc.Server) -> None: ...
12 changes: 12 additions & 0 deletions src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
GetRunRequest,
GetRunResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
Expand Down Expand Up @@ -90,3 +92,13 @@ def PushTaskRes(
request=request,
state=self.state_factory.state(),
)

def GetRun(
self, request: GetRunRequest, context: grpc.ServicerContext
) -> GetRunResponse:
"""Get run information."""
log(INFO, "FleetServicer.GetRun")
return message_handler.get_run(
request=request,
state=self.state_factory.state(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
GetRunRequest,
GetRunResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
Expand Down Expand Up @@ -101,3 +103,10 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
results={str(task_id): 0},
)
return response


def get_run(
request: GetRunRequest, state: State # pylint: disable=W0613
) -> GetRunResponse:
"""Get run information."""
return GetRunResponse()
30 changes: 30 additions & 0 deletions src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
GetRunRequest,
PingRequest,
PullTaskInsRequest,
PushTaskResRequest,
Expand Down Expand Up @@ -179,12 +180,41 @@ async def ping(request: Request) -> Response:
)


async def get_run(request: Request) -> Response:
"""GetRun."""
_check_headers(request.headers)

# Get the request body as raw bytes
get_run_request_bytes: bytes = await request.body()

# Deserialize ProtoBuf
get_run_request_proto = GetRunRequest()
get_run_request_proto.ParseFromString(get_run_request_bytes)

# Get state from app
state: State = app.state.STATE_FACTORY.state()

# Handle message
get_run_response_proto = message_handler.get_run(
request=get_run_request_proto, state=state
)

# Return serialized ProtoBuf
get_run_response_bytes = get_run_response_proto.SerializeToString()
return Response(
status_code=200,
content=get_run_response_bytes,
headers={"Content-Type": "application/protobuf"},
)


routes = [
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
]

app: Starlette = Starlette(
Expand Down