Skip to content

Commit

Permalink
centralize run-related messages and add GetRunStatus rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Sep 13, 2024
1 parent 310c234 commit 87937bb
Show file tree
Hide file tree
Showing 16 changed files with 330 additions and 186 deletions.
24 changes: 3 additions & 21 deletions src/proto/flwr/proto/control.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,12 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/fab.proto";
import "flwr/proto/transport.proto";
import "flwr/proto/run.proto";

service Control {
// Request to create a new run
rpc CreateRun(CreateRunRequest) returns (CreateRunResponse);
rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {}

// Get the status of a given run
rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse);
rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {}
}

// CreateRun
message CreateRunRequest {
string fab_id = 1;
string fab_version = 2;
map<string, Scalar> override_config = 3;
Fab fab = 4;
double ttl = 5;
}
message CreateRunResponse {
bool success = 1;
sint64 run_id = 2;
}

// GetRunStatus
message GetRunStatusRequest { sint64 run_id = 1; }
message GetRunStatusResponse { string status = 1; }
4 changes: 3 additions & 1 deletion src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/control.proto";
import "flwr/proto/node.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";
Expand All @@ -41,6 +40,9 @@ service Driver {

// Get FAB
rpc GetFab(GetFabRequest) returns (GetFabResponse) {}

// Get run status
rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {}
}

// GetNodes messages
Expand Down
3 changes: 3 additions & 0 deletions src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ service Fleet {

// Get FAB
rpc GetFab(GetFabRequest) returns (GetFabResponse) {}

// Get run status
rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {}
}

// CreateNode messages
Expand Down
27 changes: 27 additions & 0 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/fab.proto";
import "flwr/proto/transport.proto";

message Run {
Expand All @@ -26,5 +27,31 @@ message Run {
map<string, Scalar> override_config = 4;
string fab_hash = 5;
}

// CreateRun
message CreateRunRequest {
string fab_id = 1;
string fab_version = 2;
map<string, Scalar> override_config = 3;
Fab fab = 4;
double ttl = 5;
}
message CreateRunResponse {
bool success = 1;
sint64 run_id = 2;
}

// GetRun
message GetRunRequest { sint64 run_id = 1; }
message GetRunResponse { Run run = 1; }

// GetRunStatus
message GetRunStatusRequest { repeated uint64 run_ids = 1; }
message GetRunStatusResponse { map<uint64, string> run_status_dict = 1; }

// EndRun
message EndRunRequest {
uint64 run_id = 1;
string status = 2;
}
message EndRunResponse {}
21 changes: 4 additions & 17 deletions src/py/flwr/proto/control_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

86 changes: 0 additions & 86 deletions src/py/flwr/proto/control_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,92 +2,6 @@
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import flwr.proto.fab_pb2
import flwr.proto.transport_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

class CreateRunRequest(google.protobuf.message.Message):
"""CreateRun"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class OverrideConfigEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: typing.Text
@property
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
def __init__(self,
*,
key: typing.Text = ...,
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

FAB_ID_FIELD_NUMBER: builtins.int
FAB_VERSION_FIELD_NUMBER: builtins.int
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
FAB_FIELD_NUMBER: builtins.int
TTL_FIELD_NUMBER: builtins.int
fab_id: typing.Text
fab_version: typing.Text
@property
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
@property
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
ttl: builtins.float
def __init__(self,
*,
fab_id: typing.Text = ...,
fab_version: typing.Text = ...,
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
ttl: builtins.float = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","ttl",b"ttl"]) -> None: ...
global___CreateRunRequest = CreateRunRequest

class CreateRunResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SUCCESS_FIELD_NUMBER: builtins.int
RUN_ID_FIELD_NUMBER: builtins.int
success: builtins.bool
run_id: builtins.int
def __init__(self,
*,
success: builtins.bool = ...,
run_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","success",b"success"]) -> None: ...
global___CreateRunResponse = CreateRunResponse

class GetRunStatusRequest(google.protobuf.message.Message):
"""GetRunStatus"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
run_id: builtins.int
def __init__(self,
*,
run_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
global___GetRunStatusRequest = GetRunStatusRequest

class GetRunStatusResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STATUS_FIELD_NUMBER: builtins.int
status: typing.Text
def __init__(self,
*,
status: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["status",b"status"]) -> None: ...
global___GetRunStatusResponse = GetRunStatusResponse
26 changes: 13 additions & 13 deletions src/py/flwr/proto/control_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

from flwr.proto import control_pb2 as flwr_dot_proto_dot_control__pb2
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2


class ControlStub(object):
Expand All @@ -16,13 +16,13 @@ def __init__(self, channel):
"""
self.CreateRun = channel.unary_unary(
'/flwr.proto.Control/CreateRun',
request_serializer=flwr_dot_proto_dot_control__pb2.CreateRunRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_control__pb2.CreateRunResponse.FromString,
request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString,
)
self.GetRunStatus = channel.unary_unary(
'/flwr.proto.Control/GetRunStatus',
request_serializer=flwr_dot_proto_dot_control__pb2.GetRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_control__pb2.GetRunStatusResponse.FromString,
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
)


Expand All @@ -48,13 +48,13 @@ def add_ControlServicer_to_server(servicer, server):
rpc_method_handlers = {
'CreateRun': grpc.unary_unary_rpc_method_handler(
servicer.CreateRun,
request_deserializer=flwr_dot_proto_dot_control__pb2.CreateRunRequest.FromString,
response_serializer=flwr_dot_proto_dot_control__pb2.CreateRunResponse.SerializeToString,
request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString,
),
'GetRunStatus': grpc.unary_unary_rpc_method_handler(
servicer.GetRunStatus,
request_deserializer=flwr_dot_proto_dot_control__pb2.GetRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_control__pb2.GetRunStatusResponse.SerializeToString,
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
Expand All @@ -78,8 +78,8 @@ def CreateRun(request,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/CreateRun',
flwr_dot_proto_dot_control__pb2.CreateRunRequest.SerializeToString,
flwr_dot_proto_dot_control__pb2.CreateRunResponse.FromString,
flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

Expand All @@ -95,7 +95,7 @@ def GetRunStatus(request,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/GetRunStatus',
flwr_dot_proto_dot_control__pb2.GetRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_control__pb2.GetRunStatusResponse.FromString,
flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
18 changes: 9 additions & 9 deletions src/py/flwr/proto/control_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,36 @@
isort:skip_file
"""
import abc
import flwr.proto.control_pb2
import flwr.proto.run_pb2
import grpc

class ControlStub:
def __init__(self, channel: grpc.Channel) -> None: ...
CreateRun: grpc.UnaryUnaryMultiCallable[
flwr.proto.control_pb2.CreateRunRequest,
flwr.proto.control_pb2.CreateRunResponse]
flwr.proto.run_pb2.CreateRunRequest,
flwr.proto.run_pb2.CreateRunResponse]
"""Request to create a new run"""

GetRunStatus: grpc.UnaryUnaryMultiCallable[
flwr.proto.control_pb2.GetRunStatusRequest,
flwr.proto.control_pb2.GetRunStatusResponse]
flwr.proto.run_pb2.GetRunStatusRequest,
flwr.proto.run_pb2.GetRunStatusResponse]
"""Get the status of a given run"""


class ControlServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def CreateRun(self,
request: flwr.proto.control_pb2.CreateRunRequest,
request: flwr.proto.run_pb2.CreateRunRequest,
context: grpc.ServicerContext,
) -> flwr.proto.control_pb2.CreateRunResponse:
) -> flwr.proto.run_pb2.CreateRunResponse:
"""Request to create a new run"""
pass

@abc.abstractmethod
def GetRunStatus(self,
request: flwr.proto.control_pb2.GetRunStatusRequest,
request: flwr.proto.run_pb2.GetRunStatusRequest,
context: grpc.ServicerContext,
) -> flwr.proto.control_pb2.GetRunStatusResponse:
) -> flwr.proto.run_pb2.GetRunStatusResponse:
"""Get the status of a given run"""
pass

Expand Down
31 changes: 15 additions & 16 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 87937bb

Please sign in to comment.