diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 0968857bdd71..047b0d0910ff 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -30,6 +30,7 @@ service Exec { message StartRunRequest { bytes fab_file = 1; map override_config = 2; + map federation_config = 3; } message StartRunResponse { sint64 run_id = 1; } message StreamLogsRequest { sint64 run_id = 1; } diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 00588fec4224..1c57b20e3026 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -25,7 +25,7 @@ from flwr.cli.build import build from flwr.cli.config_utils import load_and_validate -from flwr.common.config import parse_config_args +from flwr.common.config import flatten_dict, parse_config_args from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log from flwr.common.serde import user_config_to_proto @@ -114,7 +114,7 @@ def run( def _run_with_superexec( - federation: Dict[str, str], + federation: Dict[str, Any], directory: Optional[Path], config_overrides: Optional[List[str]], ) -> None: @@ -168,6 +168,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: override_config=user_config_to_proto( parse_config_args(config_overrides, separator=",") ), + federation_config=user_config_to_proto(flatten_dict(federation.get("options"))), ) res = stub.StartRun(req) typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index c915a3ef1621..c83fd59df184 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -113,8 +113,13 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig: return get_fused_config_from_dir(project_dir, run.override_config) -def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig: +def flatten_dict( + raw_dict: Optional[Dict[str, Any]], parent_key: str = "" +) -> UserConfig: """Flatten dict by joining nested keys with a given separator.""" + if raw_dict is None: + return {} + items: List[Tuple[str, UserConfigValue]] = [] separator: str = "." for k, v in raw_dict.items(): diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 5f3a9f1e9f7d..6dfb061aff90 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -15,7 +15,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xb8\x01\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd3\x02\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12L\n\x11\x66\x65\x64\x65ration_config\x18\x03 \x03(\x0b\x32\x31.flwr.proto.StartRunRequest.FederationConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1aK\n\x15\x46\x65\x64\x65rationConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,16 +24,20 @@ DESCRIPTOR._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._options = None + _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_options = b'8\001' _globals['_STARTRUNREQUEST']._serialized_start=66 - _globals['_STARTRUNREQUEST']._serialized_end=250 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=177 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=250 - _globals['_STARTRUNRESPONSE']._serialized_start=252 - _globals['_STARTRUNRESPONSE']._serialized_end=286 - _globals['_STREAMLOGSREQUEST']._serialized_start=288 - _globals['_STREAMLOGSREQUEST']._serialized_end=323 - _globals['_STREAMLOGSRESPONSE']._serialized_start=325 - _globals['_STREAMLOGSRESPONSE']._serialized_end=365 - _globals['_EXEC']._serialized_start=368 - _globals['_EXEC']._serialized_end=528 + _globals['_STARTRUNREQUEST']._serialized_end=405 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=255 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=328 + _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_start=330 + _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_end=405 + _globals['_STARTRUNRESPONSE']._serialized_start=407 + _globals['_STARTRUNRESPONSE']._serialized_end=441 + _globals['_STREAMLOGSREQUEST']._serialized_start=443 + _globals['_STREAMLOGSREQUEST']._serialized_end=478 + _globals['_STREAMLOGSRESPONSE']._serialized_start=480 + _globals['_STREAMLOGSRESPONSE']._serialized_end=520 + _globals['_EXEC']._serialized_start=523 + _globals['_EXEC']._serialized_end=683 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index fc8a615a6b65..79d54a90856b 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -29,17 +29,36 @@ class StartRunRequest(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + class FederationConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_FILE_FIELD_NUMBER: builtins.int OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int + FEDERATION_CONFIG_FIELD_NUMBER: builtins.int fab_file: builtins.bytes @property def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... + @property + def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... def __init__(self, *, fab_file: builtins.bytes = ..., override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., + federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","override_config",b"override_config"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ... global___StartRunRequest = StartRunRequest class StartRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 2eb40a7464c9..bd27d6b21017 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -135,6 +135,7 @@ def start_run( self, fab_file: bytes, override_config: UserConfig, + federation_config: UserConfig, ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" try: diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index fa54590d3b7b..83aac7bd5fd6 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -47,7 +47,9 @@ def StartRun( log(INFO, "ExecServicer.StartRun") run = self.executor.start_run( - request.fab_file, user_config_from_proto(request.override_config) + request.fab_file, + user_config_from_proto(request.override_config), + user_config_from_proto(request.federation_config), ) if run is None: diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index edc91df4530e..e55427572fd9 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -36,7 +36,7 @@ def test_start_run() -> None: run_res.proc = proc executor = MagicMock() - executor.start_run = lambda _, __: run_res + executor.start_run = lambda _, __, ___: run_res context_mock = MagicMock() diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index ed941d47e764..8d630d108b66 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -51,6 +51,7 @@ def start_run( self, fab_file: bytes, override_config: UserConfig, + federation_config: UserConfig, ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version. @@ -63,6 +64,8 @@ def start_run( The Flower App Bundle file bytes. override_config: UserConfig The config overrides dict sent by the user (using `flwr run`). + federation_config: UserConfig + The federation options dict sent by the user (using `flwr run`). Returns ------- diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index 737c037375d7..be49c83be716 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -82,7 +82,10 @@ def set_config( @override def start_run( - self, fab_file: bytes, override_config: UserConfig + self, + fab_file: bytes, + override_config: UserConfig, + federation_config: UserConfig, ) -> Optional[RunTracker]: """Start run using the Flower Simulation Engine.""" try: @@ -120,7 +123,7 @@ def start_run( "--app", f"{str(fab_path)}", "--num-supernodes", - f"{self.num_supernodes}", + f"{federation_config.get('num-supernodes', self.num_supernodes)}", "--run-id", str(run_id), ]