diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 8f0f23fb7a13..24225e6bf631 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -174,8 +174,8 @@ def on_channel_state_change(channel_connectivity: str) -> None: def _run_without_superexec( config: Dict[str, Any], federation: Dict[str, Any], federation_name: str ) -> None: - server_app_ref = config["tool"]["flwr"]["components"]["serverapp"] - client_app_ref = config["tool"]["flwr"]["components"]["clientapp"] + server_app_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"] + client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] try: num_supernodes = federation["options"]["num-supernodes"] diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index a364318c766c..0ef0a145b1b6 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -248,7 +248,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: dir_path = Path(project_dir).absolute() # Set app reference - client_app_ref = config["tool"]["flwr"]["components"]["clientapp"] + client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] # Set sys.path nonlocal inserted_path diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 5b8f92b0e0f6..b6baca0dff54 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -187,7 +187,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches run_ = driver.run server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir)) config = get_project_config(server_app_dir) - server_app_attr = config["tool"]["flwr"]["components"]["serverapp"] + server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"] server_app_run_config = get_fused_config(run_, flwr_dir) else: # User provided `server-app`, but not `--run-id` diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index b652207961a1..320f839e9e01 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -61,7 +61,9 @@ def _register_nodes( def _register_node_states( - nodes_mapping: NodeToPartitionMapping, run: Run + nodes_mapping: NodeToPartitionMapping, + run: Run, + app_dir: Optional[str] = None, ) -> Dict[int, NodeState]: """Create NodeState objects and pre-register the context for the run.""" node_states: Dict[int, NodeState] = {} @@ -76,7 +78,9 @@ def _register_node_states( ) # Pre-register Context objects - node_states[node_id].register_context(run_id=run.run_id, run=run) + node_states[node_id].register_context( + run_id=run.run_id, run=run, app_dir=app_dir + ) return node_states @@ -256,6 +260,7 @@ def start_vce( backend_name: str, backend_config_json_stream: str, app_dir: str, + is_app: bool, f_stop: threading.Event, run: Run, flwr_dir: Optional[str] = None, @@ -309,7 +314,9 @@ def start_vce( ) # Construct mapping of NodeStates - node_states = _register_node_states(nodes_mapping=nodes_mapping, run=run) + node_states = _register_node_states( + nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None + ) # Load backend config log(DEBUG, "Supported backends: %s", list(supported_backends.keys())) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 4dfc08560523..33c359af5cc8 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -174,6 +174,7 @@ def start_and_shutdown( backend_config_json_stream=backend_config, state_factory=state_factory, app_dir=app_dir, + is_app=False, f_stop=f_stop, run=run, existing_nodes_mapping=nodes_mapping, diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index aeba6793355c..3fb294b1458a 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -18,14 +18,19 @@ import asyncio import json import logging +import sys import threading import traceback +from argparse import Namespace from logging import DEBUG, ERROR, INFO, WARNING +from pathlib import Path from time import sleep -from typing import Optional +from typing import Dict, List, Optional +from flwr.cli.config_utils import load_and_validate from flwr.client import ClientApp from flwr.common import EventType, event, log +from flwr.common.config import get_fused_config_from_dir, parse_config_args from flwr.common.constant import RUN_ID_NUM_BYTES from flwr.common.logger import set_logger_propagation, update_console_handler from flwr.common.typing import Run, UserConfig @@ -41,28 +46,129 @@ ) +def _check_args_do_not_interfere(args: Namespace) -> bool: + """Ensure decoupling of flags for different ways to start the simulation.""" + mode_one_args = ["app", "run_config"] + mode_two_args = ["client_app", "server_app"] + + def _resolve_message(conflict_keys: List[str]) -> str: + return ",".join([f"`--{key}`".replace("_", "-") for key in conflict_keys]) + + # When passing `--app`, `--app-dir` is ignored + if args.app and args.app_dir: + log(ERROR, "Either `--app` or `--app-dir` can be set, but not both.") + return False + + if any(getattr(args, key) for key in mode_one_args): + if any(getattr(args, key) for key in mode_two_args): + log( + ERROR, + "Passing any of {%s} alongside with any of {%s}", + _resolve_message(mode_one_args), + _resolve_message(mode_two_args), + ) + return False + + if not args.app: + log(ERROR, "You need to pass --app") + return False + + return True + + # Ensure all args are set (required for the non-FAB mode of execution) + if not all(getattr(args, key) for key in mode_two_args): + log( + ERROR, + "Passing all of %s keys are required.", + _resolve_message(mode_two_args), + ) + return False + + return True + + # Entry point from CLI +# pylint: disable=too-many-locals def run_simulation_from_cli() -> None: """Run Simulation Engine from the CLI.""" args = _parse_args_run_simulation().parse_args() + # We are supporting two modes for the CLI entrypoint: + # 1) Running an app dir containing a `pyproject.toml` + # 2) Running any ClientApp and SeverApp w/o pyproject.toml being present + # For 2), some CLI args are compulsory, but they are not required for 1) + # We first do these checks + args_check_pass = _check_args_do_not_interfere(args) + if not args_check_pass: + sys.exit("Simulation Engine cannot start.") + + run_id = ( + generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) + if args.run_id is None + else args.run_id + ) + if args.app: + # Mode 1 + app_path = Path(args.app) + if not app_path.is_dir(): + log(ERROR, "--app is not a directory") + sys.exit("Simulation Engine cannot start.") + + # Load pyproject.toml + config, errors, warnings = load_and_validate( + app_path / "pyproject.toml", check_module=False + ) + if errors: + raise ValueError(errors) + + if warnings: + log(WARNING, warnings) + + if config is None: + raise ValueError("Config extracted from FAB's pyproject.toml is not valid") + + # Get ClientApp and SeverApp components + app_components = config["tool"]["flwr"]["app"]["components"] + client_app_attr = app_components["clientapp"] + server_app_attr = app_components["serverapp"] + + override_config = parse_config_args(args.run_config) + fused_config = get_fused_config_from_dir(app_path, override_config) + app_dir = args.app + is_app = True + + else: + # Mode 2 + client_app_attr = args.client_app + server_app_attr = args.server_app + override_config = {} + fused_config = None + app_dir = args.app_dir + is_app = False + + # Create run + run = Run( + run_id=run_id, + fab_id="", + fab_version="", + override_config=override_config, + ) + # Load JSON config backend_config_dict = json.loads(args.backend_config) _run_simulation( - server_app_attr=args.server_app, - client_app_attr=args.client_app, + server_app_attr=server_app_attr, + client_app_attr=client_app_attr, num_supernodes=args.num_supernodes, backend_name=args.backend, backend_config=backend_config_dict, - app_dir=args.app_dir, - run=( - Run(run_id=args.run_id, fab_id="", fab_version="", override_config={}) - if args.run_id - else None - ), + app_dir=app_dir, + run=run, enable_tf_gpu_growth=args.enable_tf_gpu_growth, verbose_logging=args.verbose, + server_app_run_config=fused_config, + is_app=is_app, ) @@ -205,6 +311,7 @@ def _main_loop( backend_name: str, backend_config_stream: str, app_dir: str, + is_app: bool, enable_tf_gpu_growth: bool, run: Run, flwr_dir: Optional[str] = None, @@ -212,6 +319,7 @@ def _main_loop( client_app_attr: Optional[str] = None, server_app: Optional[ServerApp] = None, server_app_attr: Optional[str] = None, + server_app_run_config: Optional[UserConfig] = None, ) -> None: """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread.""" # Initialize StateFactory @@ -225,7 +333,9 @@ def _main_loop( # Register run log(DEBUG, "Pre-registering run with id %s", run.run_id) state_factory.state().run_ids[run.run_id] = run # type: ignore - server_app_run_config: UserConfig = {} + + if server_app_run_config is None: + server_app_run_config = {} # Initialize Driver driver = InMemoryDriver(run_id=run.run_id, state_factory=state_factory) @@ -251,6 +361,7 @@ def _main_loop( backend_name=backend_name, backend_config_json_stream=backend_config_stream, app_dir=app_dir, + is_app=is_app, state_factory=state_factory, f_stop=f_stop, run=run, @@ -284,11 +395,13 @@ def _run_simulation( backend_config: Optional[BackendConfig] = None, client_app_attr: Optional[str] = None, server_app_attr: Optional[str] = None, + server_app_run_config: Optional[UserConfig] = None, app_dir: str = "", flwr_dir: Optional[str] = None, run: Optional[Run] = None, enable_tf_gpu_growth: bool = False, verbose_logging: bool = False, + is_app: bool = False, ) -> None: r"""Launch the Simulation Engine. @@ -317,14 +430,18 @@ def _run_simulation( parameters. Values supported in are those included by `flwr.common.typing.ConfigsRecordValues`. - client_app_attr : str + client_app_attr : Optional[str] A path to a `ClientApp` module to be loaded: For example: `client:app` or `project.package.module:wrapper.app`." - server_app_attr : str + server_app_attr : Optional[str] A path to a `ServerApp` module to be loaded: For example: `server:app` or `project.package.module:wrapper.app`." + server_app_run_config : Optional[UserConfig] + Config dictionary that parameterizes the run config. It will be made accesible + to the ServerApp. + app_dir : str Add specified directory to the PYTHONPATH and load `ClientApp` from there. (Default: current working directory.) @@ -346,6 +463,11 @@ def _run_simulation( verbose_logging : bool (default: False) When disabled, only INFO, WARNING and ERROR log messages will be shown. If enabled, DEBUG-level logs will be displayed. + + is_app : bool (default: False) + A flag that indicates whether the simulation is running an app or not. This is + needed in order to attempt loading an app's pyproject.toml when nodes register + a context object. """ if backend_config is None: backend_config = {} @@ -381,6 +503,7 @@ def _run_simulation( backend_name, backend_config_stream, app_dir, + is_app, enable_tf_gpu_growth, run, flwr_dir, @@ -388,6 +511,7 @@ def _run_simulation( client_app_attr, server_app, server_app_attr, + server_app_run_config, ) # Detect if there is an Asyncio event loop already running. # If yes, disable logger propagation. In environmnets @@ -419,12 +543,10 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser: ) parser.add_argument( "--server-app", - required=True, help="For example: `server:app` or `project.package.module:wrapper.app`", ) parser.add_argument( "--client-app", - required=True, help="For example: `client:app` or `project.package.module:wrapper.app`", ) parser.add_argument( @@ -433,6 +555,18 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser: required=True, help="Number of simulated SuperNodes.", ) + parser.add_argument( + "--app", + type=str, + default=None, + help="Path to a directory containing a FAB-like structure with a " + "pyproject.toml.", + ) + parser.add_argument( + "--run-config", + default=None, + help="Override configuration key-value pairs.", + ) parser.add_argument( "--backend", default="ray", diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index c44df7bbcf69..b44c716b84cc 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -115,7 +115,7 @@ def start_run( ) # Get ClientApp and SeverApp components - flower_components = config["tool"]["flwr"]["components"] + flower_components = config["tool"]["flwr"]["app"]["components"] clientapp = flower_components["clientapp"] serverapp = flower_components["serverapp"]