From ea01fd1f4557c38119fbbe74de356f2f384eeb49 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Thu, 11 Jul 2024 21:09:46 +0200 Subject: [PATCH] feat(framework) Add run configs (#3725) --- src/py/flwr/cli/config_utils.py | 10 ++++++++++ src/py/flwr/cli/run/run.py | 33 +++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index d06a1d6dba96..33bf12e34b04 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -108,6 +108,14 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]: return load_from_string(toml_file.read()) +def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None: + for key, value in config_dict.items(): + if isinstance(value, dict): + _validate_run_config(config_dict[key], errors) + elif not isinstance(value, str): + errors.append(f"Config value of key {key} is not of type `str`.") + + # pylint: disable=too-many-branches def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: """Validate pyproject.toml fields.""" @@ -133,6 +141,8 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] else: if "publisher" not in config["flower"]: errors.append('Property "publisher" missing in [flower]') + if "config" in config["flower"]: + _validate_run_config(config["flower"]["config"], errors) if "components" not in config["flower"]: errors.append("Missing [flower.components] section") else: diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index f5882bd14ab8..4ee2368f5794 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -18,13 +18,14 @@ from enum import Enum from logging import DEBUG from pathlib import Path -from typing import Optional +from typing import Dict, Optional import typer from typing_extensions import Annotated from flwr.cli import config_utils from flwr.cli.build import build +from flwr.common.config import parse_config_args from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log @@ -58,15 +59,20 @@ def run( Optional[Path], typer.Option(help="Path of the Flower project to run"), ] = None, + config_overrides: Annotated[ + Optional[str], + typer.Option( + "--config", + "-c", + help="Override configuration key-value pairs", + ), + ] = None, ) -> None: """Run Flower project.""" - if use_superexec: - _start_superexec_run(directory) - return - typer.secho("Loading project configuration... ", fg=typer.colors.BLUE) - config, errors, warnings = config_utils.load_and_validate() + pyproject_path = directory / "pyproject.toml" if directory else None + config, errors, warnings = config_utils.load_and_validate(path=pyproject_path) if config is None: typer.secho( @@ -88,6 +94,12 @@ def run( typer.secho("Success", fg=typer.colors.GREEN) + if use_superexec: + _start_superexec_run( + parse_config_args(config_overrides, separator=","), directory + ) + return + server_app_ref = config["flower"]["components"]["serverapp"] client_app_ref = config["flower"]["components"]["clientapp"] @@ -115,7 +127,9 @@ def run( ) -def _start_superexec_run(directory: Optional[Path]) -> None: +def _start_superexec_run( + override_config: Dict[str, str], directory: Optional[Path] +) -> None: def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" log(DEBUG, channel_connectivity) @@ -132,6 +146,9 @@ def on_channel_state_change(channel_connectivity: str) -> None: fab_path = build(directory) - req = StartRunRequest(fab_file=Path(fab_path).read_bytes()) + req = StartRunRequest( + fab_file=Path(fab_path).read_bytes(), + override_config=override_config, + ) res = stub.StartRun(req) typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)