Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add node to all Fleet requests #4250

Merged
merged 7 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/proto/flwr/proto/fab.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/node.proto";

message Fab {
// This field is the hash of the data field. It is used to identify the data.
// The hash is calculated using the SHA-256 algorithm and is represented as a
Expand All @@ -26,5 +28,8 @@ message Fab {
bytes content = 2;
}

message GetFabRequest { string hash_str = 1; }
message GetFabRequest {
Node node = 1;
string hash_str = 2;
}
message GetFabResponse { Fab fab = 1; }
5 changes: 4 additions & 1 deletion src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ message PullTaskInsResponse {
}

// PushTaskRes messages
message PushTaskResRequest { repeated TaskRes task_res_list = 1; }
message PushTaskResRequest {
Node node = 1;
repeated TaskRes task_res_list = 2;
}
message PushTaskResResponse {
Reconnect reconnect = 1;
map<string, uint32> results = 2;
Expand Down
11 changes: 9 additions & 2 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ syntax = "proto3";
package flwr.proto;

import "flwr/proto/fab.proto";
import "flwr/proto/node.proto";
import "flwr/proto/transport.proto";

message Run {
Expand Down Expand Up @@ -47,7 +48,10 @@ message CreateRunRequest {
message CreateRunResponse { uint64 run_id = 1; }

// GetRun
message GetRunRequest { uint64 run_id = 1; }
message GetRunRequest {
Node node = 1;
uint64 run_id = 2;
}
message GetRunResponse { Run run = 1; }

// UpdateRunStatus
Expand All @@ -58,5 +62,8 @@ message UpdateRunStatusRequest {
message UpdateRunStatusResponse {}

// GetRunStatus
message GetRunStatusRequest { repeated uint64 run_ids = 1; }
message GetRunStatusRequest {
Node node = 1;
repeated uint64 run_ids = 2;
}
message GetRunStatusResponse { map<uint64, RunStatus> run_status_dict = 1; }
6 changes: 3 additions & 3 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,15 @@ def send(message: Message) -> None:
task_res = message_to_taskres(message)

# Serialize ProtoBuf to bytes
request = PushTaskResRequest(task_res_list=[task_res])
request = PushTaskResRequest(node=node, task_res_list=[task_res])
_ = retry_invoker.invoke(stub.PushTaskRes, request)

# Cleanup
metadata = None

def get_run(run_id: int) -> Run:
# Call FleetAPI
get_run_request = GetRunRequest(run_id=run_id)
get_run_request = GetRunRequest(node=node, run_id=run_id)
get_run_response: GetRunResponse = retry_invoker.invoke(
stub.GetRun,
request=get_run_request,
Expand All @@ -294,7 +294,7 @@ def get_run(run_id: int) -> Run:

def get_fab(fab_hash: str) -> Fab:
# Call FleetAPI
get_fab_request = GetFabRequest(hash_str=fab_hash)
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash)
get_fab_response: GetFabResponse = retry_invoker.invoke(
stub.GetFab,
request=get_fab_request,
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def send(message: Message) -> None:
task_res = message_to_taskres(message)

# Serialize ProtoBuf to bytes
req = PushTaskResRequest(task_res_list=[task_res])
req = PushTaskResRequest(node=node, task_res_list=[task_res])

# Send the request
res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES)
Expand All @@ -356,7 +356,7 @@ def send(message: Message) -> None:

def get_run(run_id: int) -> Run:
# Construct the request
req = GetRunRequest(run_id=run_id)
req = GetRunRequest(node=node, run_id=run_id)

# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
Expand All @@ -373,7 +373,7 @@ def get_run(run_id: int) -> Run:

def get_fab(fab_hash: str) -> Fab:
# Construct the request
req = GetFabRequest(hash_str=fab_hash)
req = GetFabRequest(node=node, hash_str=fab_hash)

# Send the request
res = _request(req, GetFabResponse, PATH_GET_FAB)
Expand Down
15 changes: 8 additions & 7 deletions src/py/flwr/proto/fab_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion src/py/flwr/proto/fab_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
isort:skip_file
"""
import builtins
import flwr.proto.node_pb2
import google.protobuf.descriptor
import google.protobuf.message
import typing
Expand Down Expand Up @@ -33,13 +34,18 @@ global___Fab = Fab

class GetFabRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NODE_FIELD_NUMBER: builtins.int
HASH_STR_FIELD_NUMBER: builtins.int
@property
def node(self) -> flwr.proto.node_pb2.Node: ...
hash_str: typing.Text
def __init__(self,
*,
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
hash_str: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ...
global___GetFabRequest = GetFabRequest

class GetFabResponse(google.protobuf.message.Message):
Expand Down
20 changes: 10 additions & 10 deletions src/py/flwr/proto/fleet_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion src/py/flwr/proto/fleet_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,19 @@ global___PullTaskInsResponse = PullTaskInsResponse
class PushTaskResRequest(google.protobuf.message.Message):
"""PushTaskRes messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NODE_FIELD_NUMBER: builtins.int
TASK_RES_LIST_FIELD_NUMBER: builtins.int
@property
def node(self) -> flwr.proto.node_pb2.Node: ...
@property
def task_res_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskRes]: ...
def __init__(self,
*,
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
task_res_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskRes]] = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_res_list",b"task_res_list"]) -> None: ...
global___PushTaskResRequest = PushTaskResRequest

class PushTaskResResponse(google.protobuf.message.Message):
Expand Down
Loading