diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index d366b47edfa4..e5b4136d18e8 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -19,6 +19,7 @@ package flwr.proto; import "flwr/proto/fab.proto"; import "flwr/proto/transport.proto"; +import "flwr/proto/recordset.proto"; service Exec { // Start run upon request @@ -31,7 +32,7 @@ service Exec { message StartRunRequest { Fab fab = 1; map override_config = 2; - map federation_config = 3; + ConfigsRecord federation_options = 3; } message StartRunResponse { uint64 run_id = 1; } message StreamLogsRequest { diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index fdb565c81377..ce23c3669df0 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -29,10 +29,18 @@ validate_federation_in_project_config, validate_project_config, ) -from flwr.common.config import flatten_dict, parse_config_args +from flwr.common.config import ( + flatten_dict, + parse_config_args, + user_config_to_configsrecord, +) from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log -from flwr.common.serde import fab_to_proto, user_config_to_proto +from flwr.common.serde import ( + configs_record_to_proto, + fab_to_proto, + user_config_to_proto, +) from flwr.common.typing import Fab from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 from flwr.proto.exec_pb2_grpc import ExecStub @@ -94,6 +102,7 @@ def run( _run_without_exec_api(app, federation_config, config_overrides, federation) +# pylint: disable-next=too-many-locals def _run_with_exec_api( app: Path, federation_config: dict[str, Any], @@ -118,12 +127,14 @@ def _run_with_exec_api( content = Path(fab_path).read_bytes() fab = Fab(fab_hash, content) + # Construct a `ConfigsRecord` out of a flattened `UserConfig` + fed_conf = flatten_dict(federation_config.get("options", {})) + c_record = user_config_to_configsrecord(fed_conf) + req = StartRunRequest( fab=fab_to_proto(fab), override_config=user_config_to_proto(parse_config_args(config_overrides)), - federation_config=user_config_to_proto( - flatten_dict(federation_config.get("options")) - ), + federation_options=configs_record_to_proto(c_record), ) res = stub.StartRun(req) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 24ccada7509a..e7f71a40951c 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -22,6 +22,7 @@ import tomli from flwr.cli.config_utils import get_fab_config, validate_fields +from flwr.common import ConfigsRecord from flwr.common.constant import ( APP_DIR, FAB_CONFIG_FILE, @@ -229,3 +230,12 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]: config["project"]["version"], f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}", ) + + +def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord: + """Construct a `ConfigsRecord` out of a `UserConfig`.""" + c_record = ConfigsRecord() + for k, v in config.items(): + c_record[k] = v + + return c_record diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 2126addddeac..e8fda2cfb4f8 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -14,9 +14,10 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 +from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xdf\x02\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12L\n\x11\x66\x65\x64\x65ration_config\x18\x03 \x03(\x0b\x32\x31.flwr.proto.StartRunRequest.FederationConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1aK\n\x15\x46\x65\x64\x65rationConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\x32\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\x32\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,20 +26,16 @@ DESCRIPTOR._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._options = None - _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_options = b'8\001' - _globals['_STARTRUNREQUEST']._serialized_start=88 - _globals['_STARTRUNREQUEST']._serialized_end=439 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=289 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=362 - _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_start=364 - _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_end=439 - _globals['_STARTRUNRESPONSE']._serialized_start=441 - _globals['_STARTRUNRESPONSE']._serialized_end=475 - _globals['_STREAMLOGSREQUEST']._serialized_start=477 - _globals['_STREAMLOGSREQUEST']._serialized_end=537 - _globals['_STREAMLOGSRESPONSE']._serialized_start=539 - _globals['_STREAMLOGSRESPONSE']._serialized_end=605 - _globals['_EXEC']._serialized_start=608 - _globals['_EXEC']._serialized_end=768 + _globals['_STARTRUNREQUEST']._serialized_start=116 + _globals['_STARTRUNREQUEST']._serialized_end=367 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=294 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=367 + _globals['_STARTRUNRESPONSE']._serialized_start=369 + _globals['_STARTRUNRESPONSE']._serialized_end=403 + _globals['_STREAMLOGSREQUEST']._serialized_start=405 + _globals['_STREAMLOGSREQUEST']._serialized_end=465 + _globals['_STREAMLOGSRESPONSE']._serialized_start=467 + _globals['_STREAMLOGSRESPONSE']._serialized_end=533 + _globals['_EXEC']._serialized_start=536 + _globals['_EXEC']._serialized_end=696 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 82486383a3c1..380c57ab0780 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import flwr.proto.fab_pb2 +import flwr.proto.recordset_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -30,38 +31,23 @@ class StartRunRequest(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - class FederationConfigEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - @property - def value(self) -> flwr.proto.transport_pb2.Scalar: ... - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - FAB_FIELD_NUMBER: builtins.int OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int - FEDERATION_CONFIG_FIELD_NUMBER: builtins.int + FEDERATION_OPTIONS_FIELD_NUMBER: builtins.int @property def fab(self) -> flwr.proto.fab_pb2.Fab: ... @property def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... @property - def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... + def federation_options(self) -> flwr.proto.recordset_pb2.ConfigsRecord: ... def __init__(self, *, fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., - federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., + federation_options: typing.Optional[flwr.proto.recordset_pb2.ConfigsRecord] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_options",b"federation_options"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_options",b"federation_options","override_config",b"override_config"]) -> None: ... global___StartRunRequest = StartRunRequest class StartRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 247c594f9766..ad059e9fbf26 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -153,7 +153,7 @@ def start_run( self, fab_file: bytes, override_config: UserConfig, - federation_config: UserConfig, + federation_options: ConfigsRecord, ) -> Optional[int]: """Start run using the Flower Deployment Engine.""" run_id = None diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 98565dfd31b7..359f87894021 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -24,7 +24,7 @@ from flwr.common.constant import LOG_STREAM_INTERVAL, Status from flwr.common.logger import log -from flwr.common.serde import user_config_from_proto +from flwr.common.serde import configs_record_from_proto, user_config_from_proto from flwr.proto import exec_pb2_grpc # pylint: disable=E0611 from flwr.proto.exec_pb2 import ( # pylint: disable=E0611 StartRunRequest, @@ -61,7 +61,7 @@ def StartRun( run_id = self.executor.start_run( request.fab.content, user_config_from_proto(request.override_config), - user_config_from_proto(request.federation_config), + configs_record_from_proto(request.federation_options), ) if run_id is None: diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index fd87b0d742be..a4c73a7b19fe 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -19,6 +19,7 @@ from subprocess import Popen from typing import Optional +from flwr.common import ConfigsRecord from flwr.common.typing import UserConfig from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate import LinkStateFactory @@ -71,7 +72,7 @@ def start_run( self, fab_file: bytes, override_config: UserConfig, - federation_config: UserConfig, + federation_options: ConfigsRecord, ) -> Optional[int]: """Start a run using the given Flower FAB ID and version. @@ -84,8 +85,8 @@ def start_run( The Flower App Bundle file bytes. override_config: UserConfig The config overrides dict sent by the user (using `flwr run`). - federation_config: UserConfig - The federation options dict sent by the user (using `flwr run`). + federation_options: ConfigsRecord + The federation options sent by the user (using `flwr run`). Returns ------- diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index 3941b0c98bc6..ee41a9c7fc45 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -25,6 +25,7 @@ from flwr.cli.config_utils import load_and_validate from flwr.cli.install import install_from_fab +from flwr.common import ConfigsRecord from flwr.common.config import unflatten_dict from flwr.common.constant import RUN_ID_NUM_BYTES from flwr.common.logger import log @@ -124,7 +125,7 @@ def start_run( self, fab_file: bytes, override_config: UserConfig, - federation_config: UserConfig, + federation_options: ConfigsRecord, ) -> Optional[int]: """Start run using the Flower Simulation Engine.""" if self.num_supernodes is None: @@ -163,14 +164,13 @@ def start_run( "Config extracted from FAB's pyproject.toml is not valid" ) - # Flatten federated config - federation_config_flat = unflatten_dict(federation_config) + # Unflatten underlaying dict + fed_opt = unflatten_dict({**federation_options}) - num_supernodes = federation_config_flat.get( - "num-supernodes", self.num_supernodes - ) - backend_cfg = federation_config_flat.get("backend", {}) - verbose: Optional[bool] = federation_config_flat.get("verbose") + # Read data + num_supernodes = fed_opt.get("num-supernodes", self.num_supernodes) + backend_cfg = fed_opt.get("backend", {}) + verbose: Optional[bool] = fed_opt.get("verbose") # In Simulation there is no SuperLink, still we create a run_id run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)