Skip to content

Commit

Permalink
feat(framework) Add run configs
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jul 5, 2024
1 parent 131edbb commit 2fe95f2
Show file tree
Hide file tree
Showing 33 changed files with 509 additions and 260 deletions.
42 changes: 42 additions & 0 deletions src/proto/flwr/proto/common.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2024 Flower Labs GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;

// List of double values, wrapped in a message so it can be used as a oneof
// member (repeated fields cannot appear directly inside a oneof).
message DoubleList {
  repeated double vals = 1;
}
// List of signed 64-bit integers (sint64, i.e. ZigZag-encoded), wrapped in a
// message so it can be used as a oneof member.
message Sint64List {
  repeated sint64 vals = 1;
}
// List of boolean values, wrapped in a message so it can be used as a oneof
// member.
message BoolList {
  repeated bool vals = 1;
}
// List of strings, wrapped in a message so it can be used as a oneof member.
message StringList {
  repeated string vals = 1;
}
// List of byte strings, wrapped in a message so it can be used as a oneof
// member.
message BytesList {
  repeated bytes vals = 1;
}

// A single config value. At most one member of `value` is set at a time:
// either one scalar (field numbers 1-5) or one list wrapper (field numbers
// 21-25).
message ConfigsRecordValue {
  oneof value {
    // Single element
    double double = 1;
    sint64 sint64 = 2;
    bool bool = 3;
    string string = 4;
    bytes bytes = 5;

    // List types (each wrapped in its own message, since repeated fields
    // cannot be oneof members)
    DoubleList double_list = 21;
    Sint64List sint64_list = 22;
    BoolList bool_list = 23;
    StringList string_list = 24;
    BytesList bytes_list = 25;
  }
}
2 changes: 2 additions & 0 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package flwr.proto;
import "flwr/proto/node.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";
import "flwr/proto/common.proto";

service Driver {
// Request run_id
Expand All @@ -42,6 +43,7 @@ service Driver {
// Request carrying the inputs for a run: the FAB identifier and version plus
// optional config overrides.
message CreateRunRequest {
  string fab_id = 1;
  string fab_version = 2;

  // Config values keyed by name.
  // NOTE(review): presumably these take precedence over the FAB's default
  // config when the run config is assembled - confirm against the consumer.
  map<string, ConfigsRecordValue> override_config = 3;
}
// Response to CreateRunRequest; carries the run's ID.
message CreateRunResponse {
  sint64 run_id = 1;
}

Expand Down
7 changes: 6 additions & 1 deletion src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/common.proto";

service Exec {
// Start run upon request
rpc StartRun(StartRunRequest) returns (StartRunResponse) {}
Expand All @@ -25,7 +27,10 @@ service Exec {
rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {}
}

message StartRunRequest { bytes fab_file = 1; }
// Request message of the Exec.StartRun RPC.
message StartRunRequest {
  // Raw bytes of the FAB file to run.
  bytes fab_file = 1;

  // Config values keyed by name.
  // NOTE(review): presumably these override the FAB's default config for
  // this run - confirm against the server-side handler.
  map<string, ConfigsRecordValue> override_config = 2;
}
// Response message of the Exec.StartRun RPC; carries the run's ID.
message StartRunResponse {
  sint64 run_id = 1;
}
// Request message of the Exec.StreamLogs RPC; selects a run by its ID.
message StreamLogsRequest {
  sint64 run_id = 1;
}
// One element of the Exec.StreamLogs response stream.
message StreamLogsResponse {
  string log_output = 1;
}
24 changes: 1 addition & 23 deletions src/proto/flwr/proto/recordset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ syntax = "proto3";

package flwr.proto;

message DoubleList { repeated double vals = 1; }
message Sint64List { repeated sint64 vals = 1; }
message BoolList { repeated bool vals = 1; }
message StringList { repeated string vals = 1; }
message BytesList { repeated bytes vals = 1; }
import "flwr/proto/common.proto";

message Array {
string dtype = 1;
Expand All @@ -42,24 +38,6 @@ message MetricsRecordValue {
}
}

message ConfigsRecordValue {
oneof value {
// Single element
double double = 1;
sint64 sint64 = 2;
bool bool = 3;
string string = 4;
bytes bytes = 5;

// List types
DoubleList double_list = 21;
Sint64List sint64_list = 22;
BoolList bool_list = 23;
StringList string_list = 24;
BytesList bytes_list = 25;
}
}

message ParametersRecord {
repeated string data_keys = 1;
repeated Array data_values = 2;
Expand Down
3 changes: 3 additions & 0 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/common.proto";

// Describes a run: its ID, the FAB identifier and version it uses, and any
// config overrides attached to it.
message Run {
  sint64 run_id = 1;
  string fab_id = 2;
  string fab_version = 3;

  // Config values keyed by name.
  // NOTE(review): presumably merged over the FAB's default config by the
  // consumer - confirm.
  map<string, ConfigsRecordValue> override_config = 4;
}
// Request for the Run with the given ID.
message GetRunRequest {
  sint64 run_id = 1;
}
// Response to GetRunRequest; carries the requested Run.
message GetRunResponse {
  Run run = 1;
}
20 changes: 16 additions & 4 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from enum import Enum
from logging import DEBUG
from pathlib import Path
from typing import Optional
from typing import Dict, Optional

import typer
from typing_extensions import Annotated
Expand All @@ -28,8 +28,11 @@
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import record_value_dict_to_proto
from flwr.common.typing import ConfigsRecordValues
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.simulation.run_simulation import _run_simulation


Expand Down Expand Up @@ -61,7 +64,7 @@ def run(
) -> None:
"""Run Flower project."""
if use_superexec:
_start_superexec_run(directory)
_start_superexec_run({}, directory)
return

typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
Expand Down Expand Up @@ -115,7 +118,9 @@ def run(
)


def _start_superexec_run(directory: Optional[Path]) -> None:
def _start_superexec_run(
override_config: Dict[str, ConfigsRecordValues], directory: Optional[Path]
) -> None:
def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)
Expand All @@ -132,6 +137,13 @@ def on_channel_state_change(channel_connectivity: str) -> None:

fab_path = build(directory)

req = StartRunRequest(fab_file=Path(fab_path).read_bytes())
req = StartRunRequest(
fab_file=Path(fab_path).read_bytes(),
override_config=record_value_dict_to_proto(
override_config,
[bool, int, float, str, bytes],
ProtoConfigsRecordValue,
),
)
res = stub.StartRun(req)
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
14 changes: 9 additions & 5 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union

from cryptography.hazmat.primitives.asymmetric import ec
from flwr.common.config import get_fused_config
from grpc import RpcError

from flwr.client.client import Client
Expand All @@ -41,6 +42,7 @@
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
from flwr.common.typing import Run

from .grpc_adapter_client.connection import grpc_adapter
from .grpc_client.connection import grpc_connection
Expand Down Expand Up @@ -315,8 +317,7 @@ def _on_backoff(retry_state: RetryState) -> None:
)

node_state = NodeState(partition_id=partition_id)
# run_id -> (fab_id, fab_version)
run_info: Dict[int, Tuple[str, str]] = {}
run_info: Dict[int, Run] = {}

while not app_state_tracker.interrupt:
sleep_duration: int = 0
Expand Down Expand Up @@ -371,13 +372,14 @@ def _on_backoff(retry_state: RetryState) -> None:
run_info[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
run_info[run_id] = ("", "")
run_info[run_id] = Run(run_id, "", "", {})

# Register context for this run
node_state.register_context(run_id=run_id)

# Retrieve context for this run
context = node_state.retrieve_context(run_id=run_id)
context.config = get_fused_config(run_info[run_id])

# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
Expand All @@ -388,7 +390,9 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle app loading and task message
try:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn(*run_info[run_id])
client_app: ClientApp = load_client_app_fn(
run_info[run_id].fab_id, run_info[run_id].fab_version
)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
Expand Down Expand Up @@ -573,7 +577,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
],
],
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/grpc_adapter_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from flwr.common.logger import log
from flwr.common.message import Message
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.typing import Run


@contextmanager
Expand All @@ -45,7 +46,7 @@ def grpc_adapter( # pylint: disable=R0913
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Primitives for request/response-based interaction with a server via GrpcAdapter.
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.typing import Run
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Reason,
Expand Down Expand Up @@ -73,7 +74,7 @@ def grpc_connection( # pylint: disable=R0913, R0915
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Establish a gRPC connection to a gRPC server.
Expand Down
18 changes: 14 additions & 4 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
record_value_dict_from_proto,
)
from flwr.common.typing import Run
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
Expand Down Expand Up @@ -80,7 +85,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Primitives for request/response-based interaction with a server.
Expand Down Expand Up @@ -266,7 +271,7 @@ def send(message: Message) -> None:
# Cleanup
metadata = None

def get_run(run_id: int) -> Tuple[str, str]:
def get_run(run_id: int) -> Run:
# Call FleetAPI
get_run_request = GetRunRequest(run_id=run_id)
get_run_response: GetRunResponse = retry_invoker.invoke(
Expand All @@ -275,7 +280,12 @@ def get_run(run_id: int) -> Tuple[str, str]:
)

# Return the run info (run_id, fab_id, fab_version and override config)
return get_run_response.run.fab_id, get_run_response.run.fab_version
return Run(
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
record_value_dict_from_proto(get_run_response.run.override_config),
)

try:
# Yield methods
Expand Down
20 changes: 15 additions & 5 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union

from cryptography.hazmat.primitives.asymmetric import ec
from flwr.common.typing import Run
from google.protobuf.message import Message as GrpcMessage

from flwr.client.heartbeat import start_ping_loop
Expand All @@ -40,7 +41,11 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
record_value_dict_from_proto,
)
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
CreateNodeResponse,
Expand Down Expand Up @@ -91,7 +96,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Primitives for request/response-based interaction with a server.
Expand Down Expand Up @@ -344,16 +349,21 @@ def send(message: Message) -> None:
res.results, # pylint: disable=no-member
)

def get_run(run_id: int) -> Tuple[str, str]:
def get_run(run_id: int) -> Run:
# Construct the request
req = GetRunRequest(run_id=run_id)

# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return "", ""
return Run(run_id, "", "", {})

return res.run.fab_id, res.run.fab_version
return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
record_value_dict_from_proto(res.run.override_config),
)

try:
# Yield methods
Expand Down
Loading

0 comments on commit 2fe95f2

Please sign in to comment.