Skip to content

Commit

Permalink
feat(framework) Send federation config to SuperExec (#3838)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
charlesbvll and danieljanes authored Jul 23, 2024
1 parent d24ebe5 commit 488bc91
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ service Exec {
message StartRunRequest {
bytes fab_file = 1;
map<string, Scalar> override_config = 2;
map<string, Scalar> federation_config = 3;
}
message StartRunResponse { sint64 run_id = 1; }
message StreamLogsRequest { sint64 run_id = 1; }
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from flwr.cli.build import build
from flwr.cli.config_utils import load_and_validate
from flwr.common.config import parse_config_args
from flwr.common.config import flatten_dict, parse_config_args
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import user_config_to_proto
Expand Down Expand Up @@ -114,7 +114,7 @@ def run(


def _run_with_superexec(
federation: Dict[str, str],
federation: Dict[str, Any],
directory: Optional[Path],
config_overrides: Optional[List[str]],
) -> None:
Expand Down Expand Up @@ -168,6 +168,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
override_config=user_config_to_proto(
parse_config_args(config_overrides, separator=",")
),
federation_config=user_config_to_proto(flatten_dict(federation.get("options"))),
)
res = stub.StartRun(req)
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
Expand Down
7 changes: 6 additions & 1 deletion src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,13 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
return get_fused_config_from_dir(project_dir, run.override_config)


def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig:
def flatten_dict(
raw_dict: Optional[Dict[str, Any]], parent_key: str = ""
) -> UserConfig:
"""Flatten dict by joining nested keys with a given separator."""
if raw_dict is None:
return {}

items: List[Tuple[str, UserConfigValue]] = []
separator: str = "."
for k, v in raw_dict.items():
Expand Down
28 changes: 16 additions & 12 deletions src/py/flwr/proto/exec_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 20 additions & 1 deletion src/py/flwr/proto/exec_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,36 @@ class StartRunRequest(google.protobuf.message.Message):
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

class FederationConfigEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: typing.Text
@property
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
def __init__(self,
*,
key: typing.Text = ...,
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

FAB_FILE_FIELD_NUMBER: builtins.int
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
fab_file: builtins.bytes
@property
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
@property
def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
def __init__(self,
*,
fab_file: builtins.bytes = ...,
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","override_config",b"override_config"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
global___StartRunRequest = StartRunRequest

class StartRunResponse(google.protobuf.message.Message):
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
) -> Optional[RunTracker]:
"""Start run using the Flower Deployment Engine."""
try:
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def StartRun(
log(INFO, "ExecServicer.StartRun")

run = self.executor.start_run(
request.fab_file, user_config_from_proto(request.override_config)
request.fab_file,
user_config_from_proto(request.override_config),
user_config_from_proto(request.federation_config),
)

if run is None:
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/superexec/exec_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_start_run() -> None:
run_res.proc = proc

executor = MagicMock()
executor.start_run = lambda _, __: run_res
executor.start_run = lambda _, __, ___: run_res

context_mock = MagicMock()

Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/superexec/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
) -> Optional[RunTracker]:
"""Start a run using the given Flower FAB ID and version.
Expand All @@ -63,6 +64,8 @@ def start_run(
The Flower App Bundle file bytes.
override_config: UserConfig
The config overrides dict sent by the user (using `flwr run`).
federation_config: UserConfig
The federation options dict sent by the user (using `flwr run`).
Returns
-------
Expand Down
7 changes: 5 additions & 2 deletions src/py/flwr/superexec/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def set_config(

@override
def start_run(
self, fab_file: bytes, override_config: UserConfig
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
) -> Optional[RunTracker]:
"""Start run using the Flower Simulation Engine."""
try:
Expand Down Expand Up @@ -120,7 +123,7 @@ def start_run(
"--app",
f"{str(fab_path)}",
"--num-supernodes",
f"{self.num_supernodes}",
f"{federation_config.get('num-supernodes', self.num_supernodes)}",
"--run-id",
str(run_id),
]
Expand Down

0 comments on commit 488bc91

Please sign in to comment.