From 2fe95f2bbbd6689f7462e474bb29fa105d101a1b Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Fri, 5 Jul 2024 19:07:17 +0200 Subject: [PATCH] feat(framework) Add run configs --- src/proto/flwr/proto/common.proto | 42 ++++++ src/proto/flwr/proto/driver.proto | 2 + src/proto/flwr/proto/exec.proto | 7 +- src/proto/flwr/proto/recordset.proto | 24 +--- src/proto/flwr/proto/run.proto | 3 + src/py/flwr/cli/run/run.py | 20 ++- src/py/flwr/client/app.py | 14 +- .../client/grpc_adapter_client/connection.py | 3 +- src/py/flwr/client/grpc_client/connection.py | 3 +- .../client/grpc_rere_client/connection.py | 18 ++- src/py/flwr/client/rest_client/connection.py | 20 ++- src/py/flwr/common/config.py | 8 ++ src/py/flwr/common/context.py | 15 ++- src/py/flwr/common/serde.py | 12 +- src/py/flwr/common/typing.py | 1 + src/py/flwr/proto/common_pb2.py | 36 +++++ src/py/flwr/proto/common_pb2.pyi | 121 +++++++++++++++++ src/py/flwr/proto/common_pb2_grpc.py | 4 + src/py/flwr/proto/common_pb2_grpc.pyi | 4 + src/py/flwr/proto/driver_pb2.py | 43 +++--- src/py/flwr/proto/driver_pb2.pyi | 22 ++- src/py/flwr/proto/exec_pb2.py | 27 ++-- src/py/flwr/proto/exec_pb2.pyi | 23 +++- src/py/flwr/proto/recordset_pb2.py | 59 ++++---- src/py/flwr/proto/recordset_pb2.pyi | 126 ++---------------- src/py/flwr/proto/run_pb2.py | 19 ++- src/py/flwr/proto/run_pb2.pyi | 23 +++- src/py/flwr/server/driver/grpc_driver.py | 7 +- src/py/flwr/server/run_serverapp.py | 21 ++- .../superlink/driver/driver_servicer.py | 7 +- src/py/flwr/server/superlink/state/state.py | 11 +- src/py/flwr/superexec/deployment.py | 17 ++- src/py/flwr/superexec/executor.py | 7 +- 33 files changed, 509 insertions(+), 260 deletions(-) create mode 100644 src/proto/flwr/proto/common.proto create mode 100644 src/py/flwr/proto/common_pb2.py create mode 100644 src/py/flwr/proto/common_pb2.pyi create mode 100644 src/py/flwr/proto/common_pb2_grpc.py create mode 100644 src/py/flwr/proto/common_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/common.proto b/src/proto/flwr/proto/common.proto new file mode 100644 index 000000000000..4f018f6f0735 --- /dev/null +++ b/src/proto/flwr/proto/common.proto @@ -0,0 +1,42 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message DoubleList { repeated double vals = 1; } +message Sint64List { repeated sint64 vals = 1; } +message BoolList { repeated bool vals = 1; } +message StringList { repeated string vals = 1; } +message BytesList { repeated bytes vals = 1; } + +message ConfigsRecordValue { + oneof value { + // Single element + double double = 1; + sint64 sint64 = 2; + bool bool = 3; + string string = 4; + bytes bytes = 5; + + // List types + DoubleList double_list = 21; + Sint64List sint64_list = 22; + BoolList bool_list = 23; + StringList string_list = 24; + BytesList bytes_list = 25; + } +} diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index edbd5d91bb5b..4f8ccbbeaf32 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -20,6 +20,7 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; import "flwr/proto/run.proto"; +import "flwr/proto/common.proto"; service Driver { // Request run_id @@ -42,6 +43,7 @@ service Driver { message CreateRunRequest { string fab_id = 1; string fab_version = 2; + map override_config = 3; } message CreateRunResponse { sint64 run_id = 1; } diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 8e5f53b02ca8..097b32d0296f 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -17,6 +17,8 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/common.proto"; + service Exec { // Start run upon request rpc StartRun(StartRunRequest) returns (StartRunResponse) {} @@ -25,7 +27,10 @@ service Exec { rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {} } -message StartRunRequest { bytes fab_file = 1; } +message StartRunRequest { + bytes fab_file = 1; + map override_config = 2; +} message StartRunResponse { sint64 run_id = 1; } message StreamLogsRequest { sint64 run_id = 1; } message StreamLogsResponse { string log_output = 1; } diff --git a/src/proto/flwr/proto/recordset.proto b/src/proto/flwr/proto/recordset.proto index d51d0f9ce416..53311f26dffe 100644 --- a/src/proto/flwr/proto/recordset.proto +++ b/src/proto/flwr/proto/recordset.proto @@ -17,11 +17,7 @@ syntax = "proto3"; package flwr.proto; -message DoubleList { repeated double vals = 1; } -message Sint64List { repeated sint64 vals = 1; } -message BoolList { repeated bool vals = 1; } -message StringList { repeated string vals = 1; } -message BytesList { repeated bytes vals = 1; } +import "flwr/proto/common.proto"; message Array { string dtype = 1; @@ -42,24 +38,6 @@ message MetricsRecordValue { } } -message ConfigsRecordValue { - oneof value { - // Single element - double double = 1; - sint64 sint64 = 2; - bool bool = 3; - string string = 4; - bytes bytes = 5; - - // List types - DoubleList double_list = 21; - Sint64List sint64_list = 22; - BoolList bool_list = 23; - StringList string_list = 24; - BytesList bytes_list = 25; - } -} - message ParametersRecord { repeated string data_keys = 1; repeated Array data_values = 2; diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 76a7fd91532f..a78dd5b6fe92 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -17,10 +17,13 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/common.proto"; + message Run { sint64 run_id = 1; string fab_id = 2; string fab_version = 3; + map override_config = 4; } message GetRunRequest { sint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index f5882bd14ab8..3f0e4fe7ad51 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -18,7 +18,7 @@ from enum import Enum from logging import DEBUG from pathlib import Path -from typing import Optional +from typing import Dict, Optional import typer from typing_extensions import Annotated @@ -28,8 +28,11 @@ from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log +from flwr.common.serde import record_value_dict_to_proto +from flwr.common.typing import ConfigsRecordValues from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 from flwr.proto.exec_pb2_grpc import ExecStub +from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue from flwr.simulation.run_simulation import _run_simulation @@ -61,7 +64,7 @@ def run( ) -> None: """Run Flower project.""" if use_superexec: - _start_superexec_run(directory) + _start_superexec_run({}, directory) return typer.secho("Loading project configuration... ", fg=typer.colors.BLUE) @@ -115,7 +118,9 @@ def run( ) -def _start_superexec_run(directory: Optional[Path]) -> None: +def _start_superexec_run( + override_config: Dict[str, ConfigsRecordValues], directory: Optional[Path] +) -> None: def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" log(DEBUG, channel_connectivity) @@ -132,6 +137,13 @@ def on_channel_state_change(channel_connectivity: str) -> None: fab_path = build(directory) - req = StartRunRequest(fab_file=Path(fab_path).read_bytes()) + req = StartRunRequest( + fab_file=Path(fab_path).read_bytes(), + override_config=record_value_dict_to_proto( + override_config, + [bool, int, float, str, bytes], + ProtoConfigsRecordValue, + ), + ) res = stub.StartRun(req) typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index d2d5a79f32f3..0a5e54422897 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -22,6 +22,7 @@ from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union from cryptography.hazmat.primitives.asymmetric import ec +from flwr.common.config import get_fused_config from grpc import RpcError from flwr.client.client import Client @@ -41,6 +42,7 @@ from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential +from flwr.common.typing import Run from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection @@ -315,8 +317,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) node_state = NodeState(partition_id=partition_id) - # run_id -> (fab_id, fab_version) - run_info: Dict[int, Tuple[str, str]] = {} + run_info: Dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -371,13 +372,14 @@ def _on_backoff(retry_state: RetryState) -> None: run_info[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - run_info[run_id] = ("", "") + run_info[run_id] = Run(run_id, "", "", {}) # Register context for this run node_state.register_context(run_id=run_id) # Retrieve context for this run context = node_state.retrieve_context(run_id=run_id) + context.config = get_fused_config(run_info[run_id]) # Create an error reply message that will never be used to prevent # the used-before-assignment linting error @@ -388,7 +390,9 @@ def _on_backoff(retry_state: RetryState) -> None: # Handle app loading and task message try: # Load ClientApp instance - client_app: ClientApp = load_client_app_fn(*run_info[run_id]) + client_app: ClientApp = load_client_app_fn( + run_info[run_id].fab_id, run_info[run_id].fab_version + ) # Execute ClientApp reply_message = client_app(message=message, context=context) @@ -573,7 +577,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ], ], diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index e4e32b3accd0..971b630e470b 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -27,6 +27,7 @@ from flwr.common.logger import log from flwr.common.message import Message from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run @contextmanager @@ -45,7 +46,7 @@ def grpc_adapter( # pylint: disable=R0913 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server via GrpcAdapter. diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 8c049861c672..3e9f261c1ecf 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -38,6 +38,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -73,7 +74,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Establish a gRPC connection to a gRPC server. diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 34dc0e417383..31ac47a5bf84 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -40,7 +40,12 @@ from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker -from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.serde import ( + message_from_taskins, + message_to_taskres, + record_value_dict_from_proto, +) +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, @@ -80,7 +85,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -266,7 +271,7 @@ def send(message: Message) -> None: # Cleanup metadata = None - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Call FleetAPI get_run_request = GetRunRequest(run_id=run_id) get_run_response: GetRunResponse = retry_invoker.invoke( @@ -275,7 +280,12 @@ def get_run(run_id: int) -> Tuple[str, str]: ) # Return fab_id and fab_version - return get_run_response.run.fab_id, get_run_response.run.fab_version + return Run( + run_id, + get_run_response.run.fab_id, + get_run_response.run.fab_version, + record_value_dict_from_proto(get_run_response.run.override_config), + ) try: # Yield methods diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index db5bd7eb6770..01417cfa91a2 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -24,6 +24,7 @@ from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union from cryptography.hazmat.primitives.asymmetric import ec +from flwr.common.typing import Run from google.protobuf.message import Message as GrpcMessage from flwr.client.heartbeat import start_ping_loop @@ -40,7 +41,11 @@ from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker -from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.serde import ( + message_from_taskins, + message_to_taskres, + record_value_dict_from_proto, +) from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -91,7 +96,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -344,16 +349,21 @@ def send(message: Message) -> None: res.results, # pylint: disable=no-member ) - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Construct the request req = GetRunRequest(run_id=run_id) # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) if res is None: - return "", "" + return Run(run_id, "", "", {}) - return res.run.fab_id, res.run.fab_version + return Run( + run_id, + res.run.fab_id, + res.run.fab_version, + record_value_dict_from_proto(res.run.override_config), + ) try: # Yield methods diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 20de00a6fba9..40a9a7bce5e3 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -18,6 +18,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Union +from flwr.common.typing import ConfigsRecordValues, Run import tomli from flwr.cli.config_utils import validate_fields @@ -71,3 +72,10 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: ) return config + + +def get_fused_config(run: Run) -> Dict[str, ConfigsRecordValues]: + """Get the config using the fab_id and the fab_version, remove the nesting by adding + the nested keys as prefixes separated by dots, and fuse it with the override + dict.""" + return {} diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 8fe0f1781817..1f89a7656107 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -16,7 +16,9 @@ from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional + +from flwr.common.typing import ConfigsRecordValues from .record import RecordSet @@ -38,11 +40,20 @@ class Context: An index that specifies the data partition that the ClientApp using this Context object should make use of. Setting this attribute is better suited for simulation or proto typing setups. + config : Dict[str, ConfigsRecordValues] + A config (key/value mapping) held by the entity in a given run and that will + stay local. It can be used at any point during the lifecycle of this entity + (e.g. across multiple rounds) """ state: RecordSet partition_id: Optional[int] + config: Dict[str, ConfigsRecordValues] - def __init__(self, state: RecordSet, partition_id: Optional[int] = None) -> None: + def __init__( + self, + state: RecordSet, + partition_id: Optional[int] = None, + ) -> None: self.state = state self.partition_id = partition_id diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 84932b806aff..dfdf84bbe3f2 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -411,7 +411,7 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any: return value -def _record_value_dict_to_proto( +def record_value_dict_to_proto( value_dict: TypedDict[str, Any], allowed_types: List[type], value_proto_class: Type[T], @@ -431,7 +431,7 @@ def proto(_v: Any) -> T: return {k: proto(v) for k, v in value_dict.items()} -def _record_value_dict_from_proto( +def record_value_dict_from_proto( value_dict_proto: MutableMapping[str, Any] ) -> Dict[str, Any]: """Deserialize the record value dict from ProtoBuf.""" @@ -476,7 +476,7 @@ def parameters_record_from_proto( def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord: """Serialize MetricsRecord to ProtoBuf.""" return ProtoMetricsRecord( - data=_record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue) + data=record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue) ) @@ -485,7 +485,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord return MetricsRecord( metrics_dict=cast( Dict[str, typing.MetricsRecordValues], - _record_value_dict_from_proto(record_proto.data), + record_value_dict_from_proto(record_proto.data), ), keep_input=False, ) @@ -494,7 +494,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord: """Serialize ConfigsRecord to ProtoBuf.""" return ProtoConfigsRecord( - data=_record_value_dict_to_proto( + data=record_value_dict_to_proto( record, [bool, int, float, str, bytes], ProtoConfigsRecordValue, @@ -507,7 +507,7 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord return ConfigsRecord( configs_dict=cast( Dict[str, typing.ConfigsRecordValues], - _record_value_dict_from_proto(record_proto.data), + record_value_dict_from_proto(record_proto.data), ), keep_input=False, ) diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index f51830955679..d124a93a5163 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -194,3 +194,4 @@ class Run: run_id: int fab_id: str fab_version: str + overrides: Dict[str, ConfigsRecordValues] diff --git a/src/py/flwr/proto/common_pb2.py b/src/py/flwr/proto/common_pb2.py new file mode 100644 index 000000000000..1025aa862933 --- /dev/null +++ b/src/py/flwr/proto/common_pb2.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/common.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/common.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05valueb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.common_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_DOUBLELIST']._serialized_start=39 + _globals['_DOUBLELIST']._serialized_end=65 + _globals['_SINT64LIST']._serialized_start=67 + _globals['_SINT64LIST']._serialized_end=93 + _globals['_BOOLLIST']._serialized_start=95 + _globals['_BOOLLIST']._serialized_end=119 + _globals['_STRINGLIST']._serialized_start=121 + _globals['_STRINGLIST']._serialized_end=147 + _globals['_BYTESLIST']._serialized_start=149 + _globals['_BYTESLIST']._serialized_end=174 + _globals['_CONFIGSRECORDVALUE']._serialized_start=177 + _globals['_CONFIGSRECORDVALUE']._serialized_end=522 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/common_pb2.pyi b/src/py/flwr/proto/common_pb2.pyi new file mode 100644 index 000000000000..e2539a7300a9 --- /dev/null +++ b/src/py/flwr/proto/common_pb2.pyi @@ -0,0 +1,121 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class DoubleList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.float]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___DoubleList = DoubleList + +class Sint64List(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.int]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___Sint64List = Sint64List + +class BoolList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BoolList = BoolList + +class StringList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[typing.Text]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___StringList = StringList + +class BytesList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BytesList = BytesList + +class ConfigsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + BOOL_FIELD_NUMBER: builtins.int + STRING_FIELD_NUMBER: builtins.int + BYTES_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + BOOL_LIST_FIELD_NUMBER: builtins.int + STRING_LIST_FIELD_NUMBER: builtins.int + BYTES_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + bool: builtins.bool + string: typing.Text + bytes: builtins.bytes + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + @property + def bool_list(self) -> global___BoolList: ... + @property + def string_list(self) -> global___StringList: ... + @property + def bytes_list(self) -> global___BytesList: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + bool: builtins.bool = ..., + string: typing.Text = ..., + bytes: builtins.bytes = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + bool_list: typing.Optional[global___BoolList] = ..., + string_list: typing.Optional[global___StringList] = ..., + bytes_list: typing.Optional[global___BytesList] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... +global___ConfigsRecordValue = ConfigsRecordValue diff --git a/src/py/flwr/proto/common_pb2_grpc.py b/src/py/flwr/proto/common_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/common_pb2_grpc.pyi b/src/py/flwr/proto/common_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index a2458b445563..347d33ca64b6 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -15,31 +15,36 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 +from flwr.proto import common_pb2 as flwr_dot_proto_dot_common__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x17\x66lwr/proto/common.proto\"\xd9\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x1aU\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST']._serialized_start=107 - _globals['_CREATERUNREQUEST']._serialized_end=162 - _globals['_CREATERUNRESPONSE']._serialized_start=164 - _globals['_CREATERUNRESPONSE']._serialized_end=199 - _globals['_GETNODESREQUEST']._serialized_start=201 - _globals['_GETNODESREQUEST']._serialized_end=234 - _globals['_GETNODESRESPONSE']._serialized_start=236 - _globals['_GETNODESRESPONSE']._serialized_end=287 - _globals['_PUSHTASKINSREQUEST']._serialized_start=289 - _globals['_PUSHTASKINSREQUEST']._serialized_end=353 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=355 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=394 - _globals['_PULLTASKRESREQUEST']._serialized_start=396 - _globals['_PULLTASKRESREQUEST']._serialized_end=466 - _globals['_PULLTASKRESRESPONSE']._serialized_start=468 - _globals['_PULLTASKRESRESPONSE']._serialized_end=533 - _globals['_DRIVER']._serialized_start=536 - _globals['_DRIVER']._serialized_end=924 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_CREATERUNREQUEST']._serialized_start=133 + _globals['_CREATERUNREQUEST']._serialized_end=350 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=265 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=350 + _globals['_CREATERUNRESPONSE']._serialized_start=352 + _globals['_CREATERUNRESPONSE']._serialized_end=387 + _globals['_GETNODESREQUEST']._serialized_start=389 + _globals['_GETNODESREQUEST']._serialized_end=422 + _globals['_GETNODESRESPONSE']._serialized_start=424 + _globals['_GETNODESRESPONSE']._serialized_end=475 + _globals['_PUSHTASKINSREQUEST']._serialized_start=477 + _globals['_PUSHTASKINSREQUEST']._serialized_end=541 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=543 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=582 + _globals['_PULLTASKRESREQUEST']._serialized_start=584 + _globals['_PULLTASKRESREQUEST']._serialized_end=654 + _globals['_PULLTASKRESRESPONSE']._serialized_start=656 + _globals['_PULLTASKRESRESPONSE']._serialized_end=721 + _globals['_DRIVER']._serialized_start=724 + _globals['_DRIVER']._serialized_end=1112 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 2d8d11fb59a3..8fb77021addc 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.common_pb2 import flwr.proto.node_pb2 import flwr.proto.task_pb2 import google.protobuf.descriptor @@ -16,16 +17,35 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class CreateRunRequest(google.protobuf.message.Message): """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.common_pb2.ConfigsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.common_pb2.ConfigsRecordValue] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]: ... def __init__(self, *, fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... global___CreateRunRequest = CreateRunRequest class CreateRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 7b037a9454c0..a76aea6a80f6 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -12,23 +12,28 @@ _sym_db = _symbol_database.Default() +from flwr.proto import common_pb2 as flwr_dot_proto_dot_common__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"#\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x17\x66lwr/proto/common.proto\"\xc4\x01\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x1aU\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.exec_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_STARTRUNREQUEST']._serialized_start=37 - _globals['_STARTRUNREQUEST']._serialized_end=72 - _globals['_STARTRUNRESPONSE']._serialized_start=74 - _globals['_STARTRUNRESPONSE']._serialized_end=108 - _globals['_STREAMLOGSREQUEST']._serialized_start=110 - _globals['_STREAMLOGSREQUEST']._serialized_end=145 - _globals['_STREAMLOGSRESPONSE']._serialized_start=147 - _globals['_STREAMLOGSRESPONSE']._serialized_end=187 - _globals['_EXEC']._serialized_start=190 - _globals['_EXEC']._serialized_end=350 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_STARTRUNREQUEST']._serialized_start=63 + _globals['_STARTRUNREQUEST']._serialized_end=259 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=174 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=259 + _globals['_STARTRUNRESPONSE']._serialized_start=261 + _globals['_STARTRUNRESPONSE']._serialized_end=295 + _globals['_STREAMLOGSREQUEST']._serialized_start=297 + _globals['_STREAMLOGSREQUEST']._serialized_end=332 + _globals['_STREAMLOGSRESPONSE']._serialized_start=334 + _globals['_STREAMLOGSRESPONSE']._serialized_end=374 + _globals['_EXEC']._serialized_start=377 + _globals['_EXEC']._serialized_end=537 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 466812808da8..01d2df6dd169 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -3,7 +3,9 @@ isort:skip_file """ import builtins +import flwr.proto.common_pb2 import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,13 +14,32 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class StartRunRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.common_pb2.ConfigsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.common_pb2.ConfigsRecordValue] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_FILE_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_file: builtins.bytes + @property + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]: ... def __init__(self, *, fab_file: builtins.bytes = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","override_config",b"override_config"]) -> None: ... global___StartRunRequest = StartRunRequest class StartRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py index f7f74d72182b..5dd5deb12876 100644 --- a/src/py/flwr/proto/recordset_pb2.py +++ b/src/py/flwr/proto/recordset_pb2.py @@ -12,9 +12,10 @@ _sym_db = _symbol_database.Default() +from flwr.proto import common_pb2 as flwr_dot_proto_dot_common__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\x1a\x17\x66lwr/proto/common.proto\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,38 +32,26 @@ _globals['_RECORDSET_METRICSENTRY']._serialized_options = b'8\001' _globals['_RECORDSET_CONFIGSENTRY']._options = None _globals['_RECORDSET_CONFIGSENTRY']._serialized_options = b'8\001' - _globals['_DOUBLELIST']._serialized_start=42 - _globals['_DOUBLELIST']._serialized_end=68 - _globals['_SINT64LIST']._serialized_start=70 - _globals['_SINT64LIST']._serialized_end=96 - _globals['_BOOLLIST']._serialized_start=98 - _globals['_BOOLLIST']._serialized_end=122 - _globals['_STRINGLIST']._serialized_start=124 - _globals['_STRINGLIST']._serialized_end=150 - _globals['_BYTESLIST']._serialized_start=152 - _globals['_BYTESLIST']._serialized_end=177 - _globals['_ARRAY']._serialized_start=179 - _globals['_ARRAY']._serialized_end=245 - _globals['_METRICSRECORDVALUE']._serialized_start=248 - _globals['_METRICSRECORDVALUE']._serialized_end=407 - _globals['_CONFIGSRECORDVALUE']._serialized_start=410 - _globals['_CONFIGSRECORDVALUE']._serialized_end=755 - _globals['_PARAMETERSRECORD']._serialized_start=757 - _globals['_PARAMETERSRECORD']._serialized_end=834 - _globals['_METRICSRECORD']._serialized_start=837 - _globals['_METRICSRECORD']._serialized_end=980 - _globals['_METRICSRECORD_DATAENTRY']._serialized_start=905 - _globals['_METRICSRECORD_DATAENTRY']._serialized_end=980 - _globals['_CONFIGSRECORD']._serialized_start=983 - _globals['_CONFIGSRECORD']._serialized_end=1126 - _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1051 - _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1126 - _globals['_RECORDSET']._serialized_start=1129 - _globals['_RECORDSET']._serialized_end=1536 - _globals['_RECORDSET_PARAMETERSENTRY']._serialized_start=1307 - _globals['_RECORDSET_PARAMETERSENTRY']._serialized_end=1386 - _globals['_RECORDSET_METRICSENTRY']._serialized_start=1388 - _globals['_RECORDSET_METRICSENTRY']._serialized_end=1461 - _globals['_RECORDSET_CONFIGSENTRY']._serialized_start=1463 - _globals['_RECORDSET_CONFIGSENTRY']._serialized_end=1536 + _globals['_ARRAY']._serialized_start=67 + _globals['_ARRAY']._serialized_end=133 + _globals['_METRICSRECORDVALUE']._serialized_start=136 + _globals['_METRICSRECORDVALUE']._serialized_end=295 + _globals['_PARAMETERSRECORD']._serialized_start=297 + _globals['_PARAMETERSRECORD']._serialized_end=374 + _globals['_METRICSRECORD']._serialized_start=377 + _globals['_METRICSRECORD']._serialized_end=520 + _globals['_METRICSRECORD_DATAENTRY']._serialized_start=445 + _globals['_METRICSRECORD_DATAENTRY']._serialized_end=520 + _globals['_CONFIGSRECORD']._serialized_start=523 + _globals['_CONFIGSRECORD']._serialized_end=666 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=591 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=666 + _globals['_RECORDSET']._serialized_start=669 + _globals['_RECORDSET']._serialized_end=1076 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_start=847 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_end=926 + _globals['_RECORDSET_METRICSENTRY']._serialized_start=928 + _globals['_RECORDSET_METRICSENTRY']._serialized_end=1001 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_start=1003 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_end=1076 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.pyi b/src/py/flwr/proto/recordset_pb2.pyi index 86244697129c..4fa24ab68677 100644 --- a/src/py/flwr/proto/recordset_pb2.pyi +++ b/src/py/flwr/proto/recordset_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.common_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -11,66 +12,6 @@ import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -class DoubleList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.float]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... -global___DoubleList = DoubleList - -class Sint64List(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.int]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... -global___Sint64List = Sint64List - -class BoolList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... -global___BoolList = BoolList - -class StringList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... -global___StringList = StringList - -class BytesList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... -global___BytesList = BytesList - class Array(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor DTYPE_FIELD_NUMBER: builtins.int @@ -103,72 +44,23 @@ class MetricsRecordValue(google.protobuf.message.Message): sint64: builtins.int @property - def double_list(self) -> global___DoubleList: + def double_list(self) -> flwr.proto.common_pb2.DoubleList: """List types""" pass @property - def sint64_list(self) -> global___Sint64List: ... + def sint64_list(self) -> flwr.proto.common_pb2.Sint64List: ... def __init__(self, *, double: builtins.float = ..., sint64: builtins.int = ..., - double_list: typing.Optional[global___DoubleList] = ..., - sint64_list: typing.Optional[global___Sint64List] = ..., + double_list: typing.Optional[flwr.proto.common_pb2.DoubleList] = ..., + sint64_list: typing.Optional[flwr.proto.common_pb2.Sint64List] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> None: ... def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","double_list","sint64_list"]]: ... global___MetricsRecordValue = MetricsRecordValue -class ConfigsRecordValue(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - DOUBLE_FIELD_NUMBER: builtins.int - SINT64_FIELD_NUMBER: builtins.int - BOOL_FIELD_NUMBER: builtins.int - STRING_FIELD_NUMBER: builtins.int - BYTES_FIELD_NUMBER: builtins.int - DOUBLE_LIST_FIELD_NUMBER: builtins.int - SINT64_LIST_FIELD_NUMBER: builtins.int - BOOL_LIST_FIELD_NUMBER: builtins.int - STRING_LIST_FIELD_NUMBER: builtins.int - BYTES_LIST_FIELD_NUMBER: builtins.int - double: builtins.float - """Single element""" - - sint64: builtins.int - bool: builtins.bool - string: typing.Text - bytes: builtins.bytes - @property - def double_list(self) -> global___DoubleList: - """List types""" - pass - @property - def sint64_list(self) -> global___Sint64List: ... - @property - def bool_list(self) -> global___BoolList: ... - @property - def string_list(self) -> global___StringList: ... - @property - def bytes_list(self) -> global___BytesList: ... - def __init__(self, - *, - double: builtins.float = ..., - sint64: builtins.int = ..., - bool: builtins.bool = ..., - string: typing.Text = ..., - bytes: builtins.bytes = ..., - double_list: typing.Optional[global___DoubleList] = ..., - sint64_list: typing.Optional[global___Sint64List] = ..., - bool_list: typing.Optional[global___BoolList] = ..., - string_list: typing.Optional[global___StringList] = ..., - bytes_list: typing.Optional[global___BytesList] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... -global___ConfigsRecordValue = ConfigsRecordValue - class ParametersRecord(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor DATA_KEYS_FIELD_NUMBER: builtins.int @@ -220,21 +112,21 @@ class ConfigsRecord(google.protobuf.message.Message): VALUE_FIELD_NUMBER: builtins.int key: typing.Text @property - def value(self) -> global___ConfigsRecordValue: ... + def value(self) -> flwr.proto.common_pb2.ConfigsRecordValue: ... def __init__(self, *, key: typing.Text = ..., - value: typing.Optional[global___ConfigsRecordValue] = ..., + value: typing.Optional[flwr.proto.common_pb2.ConfigsRecordValue] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... DATA_FIELD_NUMBER: builtins.int @property - def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecordValue]: ... + def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]: ... def __init__(self, *, - data: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecordValue]] = ..., + data: typing.Optional[typing.Mapping[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]] = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... global___ConfigsRecord = ConfigsRecord diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 13f06e7169aa..ecf45903dc02 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -12,19 +12,24 @@ _sym_db = _symbol_database.Default() +from flwr.proto import common_pb2 as flwr_dot_proto_dot_common__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x17\x66lwr/proto/common.proto\"\xcf\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1aU\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_RUN']._serialized_start=36 - _globals['_RUN']._serialized_end=94 - _globals['_GETRUNREQUEST']._serialized_start=96 - _globals['_GETRUNREQUEST']._serialized_end=127 - _globals['_GETRUNRESPONSE']._serialized_start=129 - _globals['_GETRUNRESPONSE']._serialized_end=175 + _globals['_RUN_OVERRIDECONFIGENTRY']._options = None + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_RUN']._serialized_start=62 + _globals['_RUN']._serialized_end=269 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=184 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=269 + _globals['_GETRUNREQUEST']._serialized_start=271 + _globals['_GETRUNREQUEST']._serialized_end=302 + _globals['_GETRUNRESPONSE']._serialized_start=304 + _globals['_GETRUNRESPONSE']._serialized_end=350 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 401d27855a41..463054927fb5 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -3,7 +3,9 @@ isort:skip_file """ import builtins +import flwr.proto.common_pb2 import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,19 +14,38 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class Run(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.common_pb2.ConfigsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.common_pb2.ConfigsRecordValue] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + RUN_ID_FIELD_NUMBER: builtins.int FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int run_id: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]: ... def __init__(self, *, run_id: builtins.int = ..., fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.common_pb2.ConfigsRecordValue]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run class GetRunRequest(google.protobuf.message.Message): diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index e614df659e3f..573619613616 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -24,7 +24,11 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.common.serde import ( + message_from_taskres, + message_to_taskins, + record_value_dict_from_proto, +) from flwr.common.typing import Run from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -206,6 +210,7 @@ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: run_id=res.run.run_id, fab_id=res.run.fab_id, fab_version=res.run.fab_version, + overrides=record_value_dict_from_proto(res.run.override_config), ) return self.stub, self._run.run_id diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 3505ebfdb0a9..275210afee10 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -19,12 +19,18 @@ import sys from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Optional +from typing import Dict, Optional from flwr.common import Context, EventType, RecordSet, event -from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir +from flwr.common.config import ( + get_flwr_dir, + get_fused_config, + get_project_config, + get_project_dir, +) from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app +from flwr.common.typing import ConfigsRecordValues from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 from .driver import Driver @@ -37,6 +43,7 @@ def run( driver: Driver, server_app_dir: str, + server_config: Dict[str, ConfigsRecordValues], server_app_attr: Optional[str] = None, loaded_server_app: Optional[ServerApp] = None, ) -> None: @@ -70,6 +77,7 @@ def _load() -> ServerApp: # Initialize Context context = Context(state=RecordSet()) + context.config = server_config # Call ServerApp server_app(driver=driver, context=context) @@ -160,6 +168,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches # Initialize GrpcDriver driver = GrpcDriver(run_id=run_id, stub=stub) + server_config = {} # Dynamically obtain ServerApp path based on run_id if args.run_id is not None: @@ -169,6 +178,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir)) config = get_project_config(server_app_dir) server_app_attr = config["flower"]["components"]["serverapp"] + server_config = get_fused_config(run_) else: # User provided `server-app`, but not `--run-id` server_app_dir = str(Path(args.dir).absolute()) @@ -182,7 +192,12 @@ def run_server_app() -> None: # pylint: disable=too-many-branches ) # Run the ServerApp with the Driver - run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) + run( + driver=driver, + server_app_dir=server_app_dir, + server_config=server_config, + server_app_attr=server_app_attr, + ) # Clean up driver.close() diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 03128f02158e..0982bd4afd9f 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -23,6 +23,7 @@ import grpc from flwr.common.logger import log +from flwr.common.serde import record_value_dict_from_proto from flwr.proto import driver_pb2_grpc # pylint: disable=E0611 from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -69,7 +70,11 @@ def CreateRun( """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") state: State = self.state_factory.state() - run_id = state.create_run(request.fab_id, request.fab_version) + run_id = state.create_run( + request.fab_id, + request.fab_version, + record_value_dict_from_proto(request.override_config), + ) return CreateRunResponse(run_id=run_id) def PushTaskIns( diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 65e2c63cab69..bb3a00731414 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,10 +16,10 @@ import abc -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set from uuid import UUID -from flwr.common.typing import Run +from flwr.common.typing import ConfigsRecordValues, Run from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" @abc.abstractmethod - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, ConfigsRecordValues], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" @abc.abstractmethod diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 6f931e81eefa..c7f44689574f 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -17,7 +17,7 @@ import subprocess import sys from logging import ERROR, INFO -from typing import Optional +from typing import Dict, Optional from typing_extensions import override @@ -25,6 +25,7 @@ from flwr.cli.install import install_from_fab from flwr.common.grpc import create_channel from flwr.common.logger import log +from flwr.common.typing import ConfigsRecordValues from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 from flwr.proto.driver_pb2_grpc import DriverStub from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER @@ -53,18 +54,22 @@ def _connect(self) -> None: ) self.stub = DriverStub(channel) - def _create_run(self, fab_id: str, fab_version: str) -> int: + def _create_run(self, fab_id: str, fab_version: str, override_config) -> int: if self.stub is None: self._connect() assert self.stub is not None - req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version) + req = CreateRunRequest( + fab_id=fab_id, fab_version=fab_version, override_config=override_config + ) res = self.stub.CreateRun(request=req) return int(res.run_id) @override - def start_run(self, fab_file: bytes) -> Optional[RunTracker]: + def start_run( + self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues] + ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" try: # Install FAB to flwr dir @@ -79,9 +84,11 @@ def start_run(self, fab_file: bytes) -> Optional[RunTracker]: ) # Call SuperLink to create run - run_id: int = self._create_run(fab_id, fab_version) + run_id: int = self._create_run(fab_id, fab_version, override_config) log(INFO, "Created run %s", str(run_id)) + # log_to_db(run_id, "overrides", override_config) + # Start ServerApp proc = subprocess.Popen( # pylint: disable=consider-using-with [ diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index f85ac4c157fc..c0ebe044a73d 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -17,7 +17,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from subprocess import Popen -from typing import Optional +from typing import Dict, Optional + +from flwr.common.typing import ConfigsRecordValues @dataclass @@ -33,8 +35,7 @@ class Executor(ABC): @abstractmethod def start_run( - self, - fab_file: bytes, + self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues] ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version.