Skip to content

Commit

Permalink
add run proto
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jun 11, 2024
1 parent c21c913 commit 82a399a
Show file tree
Hide file tree
Showing 25 changed files with 251 additions and 131 deletions.
4 changes: 4 additions & 0 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";

service Driver {
// Request run_id
Expand All @@ -32,6 +33,9 @@ service Driver {

// Get task results
rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {}

// Get run details
rpc GetRun(GetRunRequest) returns (GetRunResponse) {}
}

// CreateRun
Expand Down
10 changes: 1 addition & 9 deletions src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";

service Fleet {
rpc CreateNode(CreateNodeRequest) returns (CreateNodeResponse) {}
Expand Down Expand Up @@ -70,13 +71,4 @@ message PushTaskResResponse {
map<string, uint32> results = 2;
}

// GetRun messages
message Run {
sint64 run_id = 1;
string fab_id = 2;
string fab_version = 3;
}
message GetRunRequest { sint64 run_id = 1; }
message GetRunResponse { Run run = 1; }

message Reconnect { uint64 reconnect = 1; }
26 changes: 26 additions & 0 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2024 Flower Labs GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;

message Run {
sint64 run_id = 1;
string fab_id = 2;
string fab_version = 3;
}
message GetRunRequest { sint64 run_id = 1; }
message GetRunResponse { Run run = 1; }
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_rere_client/client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
GetRunRequest,
PingRequest,
PullTaskInsRequest,
PushTaskResRequest,
)
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611

_PUBLIC_KEY_HEADER = "public-key"
_AUTH_TOKEN_HEADER = "auth-token"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
GetRunRequest,
GetRunResponse,
PullTaskInsRequest,
PullTaskInsResponse,
PushTaskResRequest,
PushTaskResResponse,
)
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611

from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request

Expand Down
3 changes: 1 addition & 2 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,14 @@
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
GetRunRequest,
GetRunResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
PushTaskResRequest,
)
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

from .client_interceptor import AuthenticateClientInterceptor
Expand Down
3 changes: 1 addition & 2 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
GetRunRequest,
GetRunResponse,
PingRequest,
PingResponse,
PullTaskInsRequest,
Expand All @@ -56,6 +54,7 @@
PushTaskResResponse,
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

try:
Expand Down
39 changes: 20 additions & 19 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import grpc

from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2


class DriverStub(object):
Expand Down Expand Up @@ -34,6 +35,11 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString,
)
self.GetRun = channel.unary_unary(
'/flwr.proto.Driver/GetRun',
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString,
)


class DriverServicer(object):
Expand Down Expand Up @@ -67,6 +73,13 @@ def PullTaskRes(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetRun(self, request, context):
"""Get run details
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_DriverServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -90,6 +103,11 @@ def add_DriverServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.SerializeToString,
),
'GetRun': grpc.unary_unary_rpc_method_handler(
servicer.GetRun,
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Driver', rpc_method_handlers)
Expand Down Expand Up @@ -167,3 +185,20 @@ def PullTaskRes(request,
flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetRun(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/GetRun',
flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
14 changes: 14 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ isort:skip_file
"""
import abc
import flwr.proto.driver_pb2
import flwr.proto.run_pb2
import grpc

class DriverStub:
Expand All @@ -28,6 +29,11 @@ class DriverStub:
flwr.proto.driver_pb2.PullTaskResResponse]
"""Get task results"""

GetRun: grpc.UnaryUnaryMultiCallable[
flwr.proto.run_pb2.GetRunRequest,
flwr.proto.run_pb2.GetRunResponse]
"""Get run details"""


class DriverServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand Down Expand Up @@ -62,5 +68,13 @@ class DriverServicer(metaclass=abc.ABCMeta):
"""Get task results"""
pass

@abc.abstractmethod
def GetRun(self,
request: flwr.proto.run_pb2.GetRunRequest,
context: grpc.ServicerContext,
) -> flwr.proto.run_pb2.GetRunResponse:
"""Get run details"""
pass


def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
Loading

0 comments on commit 82a399a

Please sign in to comment.