From 9831e8a68abfbd6024d9d34076299cc2e17c0ed0 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Sun, 14 Jul 2024 22:27:28 +0200 Subject: [PATCH 01/16] docs(framework) Document public/private API approach (#3562) --- ...or-explanation-public-and-private-apis.rst | 118 ++++++++++++++++++ doc/source/index.rst | 1 + 2 files changed, 119 insertions(+) create mode 100644 doc/source/contributor-explanation-public-and-private-apis.rst diff --git a/doc/source/contributor-explanation-public-and-private-apis.rst b/doc/source/contributor-explanation-public-and-private-apis.rst new file mode 100644 index 000000000000..1dfdf88f97d3 --- /dev/null +++ b/doc/source/contributor-explanation-public-and-private-apis.rst @@ -0,0 +1,118 @@ +Public and private APIs +======================= + +In Python, everything is public. +To enable developers to understand which components can be relied upon, Flower declares a public API. +Components that are part of the public API can be relied upon. +Changes to the public API are announced in the release notes and are subject to deprecation policies. + +Everything that is not part of the public API is part of the private API. +Even though Python allows accessing them, user code should never use those components. +Private APIs can change at any time, even in patch releases. + +How can you determine whether a component is part of the public API or not? Easy: + +- `Use the Flower API reference documentation `_ +- `Use the Flower CLI reference documentation `_ + +Everything listed in the reference documentation is part of the public API. +This document explains how Flower maintainers define the public API and how you can determine whether a component is part of the public API or not by reading the Flower source code. + +Flower public API +----------------- + +Flower has a well-defined public API. Let's look at this in more detail. + +.. important:: + + Every component that is reachable by recursively following ``__init__.__all__`` starting from the root package (``flwr``) is part of the public API. + +If you want to determine whether a component (class/function/generator/...) is part of the public API or not, you need to start at the root of the ``flwr`` package. +Let's use ``tree -L 1 -d src/py/flwr`` to look at the Python sub-packages contained ``flwr``: + +.. code-block:: bash + + flwr + ├── cli + ├── client + ├── common + ├── proto + ├── server + └── simulation + +Contrast this with the definition of ``__all__`` in the root ``src/py/flwr/__init__.py``: + +.. code-block:: python + + # From `flwr/__init__.py` + __all__ = [ + "client", + "common", + "server", + "simulation", + ] + +You can see that ``flwr`` has six subpackages (``cli``, ``client``, ``common``, ``proto``, ``server``, ``simulation``), but only four of them are "exported" via ``__all__`` (``client``, ``common``, ``server``, ``simulation``). + +What does this mean? It means that ``client``, ``common``, ``server`` and ``simulation`` are part of the public API, but ``cli`` and ``proto`` are not. +The ``flwr`` subpackages ``cli`` and ``proto`` are private APIs. +A private API can change completely from one release to the next (even in patch releases). +It can change in a breaking way, it can be renamed (for example, ``flwr.cli`` could be renamed to ``flwr.command``) and it can even be removed completely. + +Therefore, as a Flower user: + +- ``from flwr import client`` ✅ Ok, you're importing a public API. +- ``from flwr import proto`` ❌ Not recommended, you're importing a private API. + +What about components that are nested deeper in the hierarchy? Let's look at Flower strategies to see another typical pattern. +Flower strategies like ``FedAvg`` are often imported using ``from flwr.server.strategy import FedAvg``. +Let's look at ``src/py/flwr/server/strategy/__init__.py``: + +.. code-block:: python + + from .fedavg import FedAvg as FedAvg + # ... more imports + + __all__ = [ + "FedAvg", + # ... more exports + ] + +What's notable here is that all strategies are implemented in dedicated modules (e.g., ``fedavg.py``). +In ``__init__.py``, we *import* the components we want to make part of the public API and then *export* them via ``__all__``. +Note that we export the component itself (for example, the ``FedAvg`` class), but not the module it is defined in (for example, ``fedavg.py``). +This allows us to move the definition of ``FedAvg`` into a different module (or even a module in a subpackage) without breaking the public API (as long as we update the import path in ``__init__.py``). + +Therefore: + +- ``from flwr.server.strategy import FedAvg`` ✅ Ok, you're importing a class that is part of the public API. +- ``from flwr.server.strategy import fedavg`` ❌ Not recommended, you're importing a private module. + +This approach is also implemented in the tooling that automatically builds API reference docs. + +Flower public API of private packages +------------------------------------- + +We also use this to define the public API of private subpackages. +Public, in this context, means the API that other ``flwr`` subpackages should use. +For example, ``flwr.server.driver`` is a private subpackage (it's not exported via ``src/py/flwr/server/__init__.py``'s ``__all__``). + +Still, the private sub-package ``flwr.server.driver`` defines a "public" API using ``__all__`` in ``src/py/flwr/server/driver/__init__.py``: + +.. code-block:: python + + from .driver import Driver + from .grpc_driver import GrpcDriver + from .inmemory_driver import InMemoryDriver + + __all__ = [ + "Driver", + "GrpcDriver", + "InMemoryDriver", + ] + +The interesting part is that both ``GrpcDriver`` and ``InMemoryDriver`` are never used by Flower framework users, only by other parts of the Flower framework codebase. +Those other parts of the codebase import, for example, ``InMemoryDriver`` using ``from flwr.server.driver import InMemoryDriver`` (i.e., the ``InMemoryDriver`` exported via ``__all__``), not ``from flwr.server.driver.in_memory_driver import InMemoryDriver`` (``in_memory_driver.py`` is the module containing the actual ``InMemoryDriver`` class definition). + +This is because ``flwr.server.driver`` defines a public interface for other ``flwr`` subpackages. +This allows codeowners of ``flwr.server.driver`` to refactor the package without breaking other ``flwr``-internal users. diff --git a/doc/source/index.rst b/doc/source/index.rst index f62c5ebf4786..a0115620fce9 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -174,6 +174,7 @@ The Flower community welcomes contributions. The following docs are intended to :caption: Contributor explanations contributor-explanation-architecture + contributor-explanation-public-and-private-apis .. toctree:: :maxdepth: 1 From db9759733e49aed6df6eb971840087974c63ac2e Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 14 Jul 2024 22:43:38 +0200 Subject: [PATCH 02/16] fix(framework:skip) Use correct arguments (#3799) --- src/py/flwr/superexec/deployment.py | 34 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 3a3bc3bf2b1e..bbe7882692f0 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -144,21 +144,27 @@ def start_run( run_id: int = self._create_run(fab_id, fab_version, override_config) log(INFO, "Created run %s", str(run_id)) - # Start ServerApp + command = [ + "flower-server-app", + "--run-id", + str(run_id), + "--superlink", + str(self.superlink), + ] + + if self.flwr_dir: + command.append("--flwr-dir") + command.append(self.flwr_dir) + + if self.root_certificates is None: + command.append("--insecure") + else: + command.append("--root-certificates") + command.append(self.root_certificates) + + # Execute the command proc = subprocess.Popen( # pylint: disable=consider-using-with - [ - "flower-server-app", - "--run-id", - str(run_id), - f"--flwr-dir {self.flwr_dir}" if self.flwr_dir else "", - "--superlink", - self.superlink, - ( - "--insecure" - if self.root_certificates is None - else f"--root-certificates {self.root_certificates}" - ), - ], + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, From 5155a62de2721fd58fd85cecfcc51c4302b940fb Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 15 Jul 2024 10:12:54 +0200 Subject: [PATCH 03/16] feat(framework) Add simulation engine `SuperExec` plugin (#3589) Co-authored-by: Charles Beauville Co-authored-by: Daniel J. Beutel --- src/py/flwr/superexec/simulation.py | 157 ++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 src/py/flwr/superexec/simulation.py diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py new file mode 100644 index 000000000000..9a8e19365ab9 --- /dev/null +++ b/src/py/flwr/superexec/simulation.py @@ -0,0 +1,157 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simulation engine executor.""" + + +import subprocess +import sys +from logging import ERROR, INFO, WARN +from typing import Dict, Optional + +from typing_extensions import override + +from flwr.cli.config_utils import load_and_validate +from flwr.cli.install import install_from_fab +from flwr.common.constant import RUN_ID_NUM_BYTES +from flwr.common.logger import log +from flwr.server.superlink.state.utils import generate_rand_int_from_bytes + +from .executor import Executor, RunTracker + + +class SimulationEngine(Executor): + """Simulation engine executor. + + Parameters + ---------- + num_supernodes: Opitonal[str] (default: None) + Total number of nodes to involve in the simulation. + """ + + def __init__( + self, + num_supernodes: Optional[str] = None, + ) -> None: + self.num_supernodes = num_supernodes + + @override + def set_config( + self, + config: Dict[str, str], + ) -> None: + """Set executor config arguments. + + Parameters + ---------- + config : Dict[str, str] + A dictionary for configuration values. + Supported configuration key/value pairs: + - "num-supernodes": str + Number of nodes to register for the simulation. + """ + if not config: + return + if num_supernodes := config.get("num-supernodes"): + self.num_supernodes = num_supernodes + + # Validate config + if self.num_supernodes is None: + log( + ERROR, + "To start a run with the simulation plugin, please specify " + "the number of SuperNodes. This can be done by using the " + "`--executor-config` argument when launching the SuperExec.", + ) + raise ValueError("`num-supernodes` must not be `None`") + + @override + def start_run( + self, fab_file: bytes, override_config: Dict[str, str] + ) -> Optional[RunTracker]: + """Start run using the Flower Simulation Engine.""" + try: + if override_config: + raise ValueError( + "Overriding the run config is not yet supported with the " + "simulation executor.", + ) + + # Install FAB to flwr dir + fab_path = install_from_fab(fab_file, None, True) + + # Install FAB Python package + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "--no-deps", str(fab_path)], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + # Load and validate config + config, errors, warnings = load_and_validate(fab_path / "pyproject.toml") + if errors: + raise ValueError(errors) + + if warnings: + log(WARN, warnings) + + if config is None: + raise ValueError( + "Config extracted from FAB's pyproject.toml is not valid" + ) + + # Get ClientApp and SeverApp components + flower_components = config["flower"]["components"] + clientapp = flower_components["clientapp"] + serverapp = flower_components["serverapp"] + + # In Simulation there is no SuperLink, still we create a run_id + run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) + log(INFO, "Created run %s", str(run_id)) + + # Prepare commnand + command = [ + "flower-simulation", + "--client-app", + f"{clientapp}", + "--server-app", + f"{serverapp}", + "--num-supernodes", + f"{self.num_supernodes}", + "--run-id", + str(run_id), + ] + + # Start Simulation + proc = subprocess.Popen( # pylint: disable=consider-using-with + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + log(INFO, "Started run %s", str(run_id)) + + return RunTracker( + run_id=run_id, + proc=proc, + ) + + # pylint: disable-next=broad-except + except Exception as e: + log(ERROR, "Could not start run: %s", str(e)) + return None + + +executor = SimulationEngine() From 31b86b0acdbefa87ce288ada9f3190035d42e2ea Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 15 Jul 2024 16:06:16 +0200 Subject: [PATCH 04/16] refactor(framework) Replace `run_id` with `Run` in simulation (#3802) --- src/py/flwr/simulation/run_simulation.py | 48 ++++++++++++------------ 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index de101fe3e09f..7060a972dd9a 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -26,14 +26,16 @@ from flwr.client import ClientApp from flwr.common import EventType, event, log +from flwr.common.constant import RUN_ID_NUM_BYTES from flwr.common.logger import set_logger_propagation, update_console_handler from flwr.common.typing import Run from flwr.server.driver import Driver, InMemoryDriver -from flwr.server.run_serverapp import run +from flwr.server.run_serverapp import run as run_server_app from flwr.server.server_app import ServerApp from flwr.server.superlink.fleet import vce from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state.utils import generate_rand_int_from_bytes from flwr.simulation.ray_transport.utils import ( enable_tf_gpu_growth as enable_gpu_growth, ) @@ -54,7 +56,11 @@ def run_simulation_from_cli() -> None: backend_name=args.backend, backend_config=backend_config_dict, app_dir=args.app_dir, - run_id=args.run_id, + run=( + Run(run_id=args.run_id, fab_id="", fab_version="", override_config={}) + if args.run_id + else None + ), enable_tf_gpu_growth=args.enable_tf_gpu_growth, verbose_logging=args.verbose, ) @@ -156,7 +162,7 @@ def server_th_with_start_checks( enable_gpu_growth() # Run ServerApp - run( + run_server_app( driver=_driver, server_app_dir=_server_app_dir, server_app_run_config=_server_app_run_config, @@ -193,16 +199,6 @@ def server_th_with_start_checks( return serverapp_th -def _override_run_id(state: StateFactory, run_id_to_replace: int, run_id: int) -> None: - """Override the run_id of an existing Run.""" - log(DEBUG, "Pre-registering run with id %s", run_id) - # Remove run - run_info: Run = state.state().run_ids.pop(run_id_to_replace) # type: ignore - # Update with new run_id and insert back in state - run_info.run_id = run_id - state.state().run_ids[run_id] = run_info # type: ignore - - # pylint: disable=too-many-locals def _main_loop( num_supernodes: int, @@ -210,7 +206,7 @@ def _main_loop( backend_config_stream: str, app_dir: str, enable_tf_gpu_growth: bool, - run_id: Optional[int] = None, + run: Run, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, server_app: Optional[ServerApp] = None, @@ -225,16 +221,13 @@ def _main_loop( server_app_thread_has_exception = threading.Event() serverapp_th = None try: - # Create run (with empty fab_id and fab_version) - run_id_ = state_factory.state().create_run("", "", {}) + # Register run + log(DEBUG, "Pre-registering run with id %s", run.run_id) + state_factory.state().run_ids[run.run_id] = run # type: ignore server_app_run_config: Dict[str, str] = {} - if run_id: - _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id) - run_id_ = run_id - # Initialize Driver - driver = InMemoryDriver(run_id=run_id_, state_factory=state_factory) + driver = InMemoryDriver(run_id=run.run_id, state_factory=state_factory) # Get and run ServerApp thread serverapp_th = run_serverapp_th( @@ -289,7 +282,7 @@ def _run_simulation( client_app_attr: Optional[str] = None, server_app_attr: Optional[str] = None, app_dir: str = "", - run_id: Optional[int] = None, + run: Optional[Run] = None, enable_tf_gpu_growth: bool = False, verbose_logging: bool = False, ) -> None: @@ -332,8 +325,8 @@ def _run_simulation( Add specified directory to the PYTHONPATH and load `ClientApp` from there. (Default: current working directory.) - run_id : Optional[int] - An integer specifying the ID of the run started when running this function. + run : Optional[Run] + An object carrying details about the run. enable_tf_gpu_growth : bool (default: False) A boolean to indicate whether to enable GPU growth on the main thread. This is @@ -371,13 +364,18 @@ def _run_simulation( # Convert config to original JSON-stream format backend_config_stream = json.dumps(backend_config) + # If no `Run` object is set, create one + if run is None: + run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) + run = Run(run_id=run_id, fab_id="", fab_version="", override_config={}) + args = ( num_supernodes, backend_name, backend_config_stream, app_dir, enable_tf_gpu_growth, - run_id, + run, client_app, client_app_attr, server_app, From ee5b2878bd6efe35a649f044a16f9e5e6f95d4a2 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 15 Jul 2024 16:29:08 +0200 Subject: [PATCH 05/16] refactor(framework) Register `Context` early in Simulation Engine (#3804) --- .../server/superlink/fleet/vce/vce_api.py | 42 ++++++++++++------- .../superlink/fleet/vce/vce_api_test.py | 3 ++ src/py/flwr/simulation/run_simulation.py | 1 + 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index cd30c40167c5..66bca5a391c7 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -39,6 +39,7 @@ from flwr.common.message import Error from flwr.common.object_ref import load_app from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import State, StateFactory @@ -60,6 +61,27 @@ def _register_nodes( return nodes_mapping +def _register_node_states( + nodes_mapping: NodeToPartitionMapping, run: Run +) -> Dict[int, NodeState]: + """Create NodeState objects and pre-register the context for the run.""" + node_states: Dict[int, NodeState] = {} + num_partitions = len(set(nodes_mapping.values())) + for node_id, partition_id in nodes_mapping.items(): + node_states[node_id] = NodeState( + node_id=node_id, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(num_partitions), + }, + ) + + # Pre-register Context objects + node_states[node_id].register_context(run_id=run.run_id, run=run) + + return node_states + + # pylint: disable=too-many-arguments,too-many-locals def worker( app_fn: Callable[[], ClientApp], @@ -78,8 +100,7 @@ def worker( task_ins: TaskIns = taskins_queue.get(timeout=1.0) node_id = task_ins.task.consumer.node_id - # Register and retrieve context - node_states[node_id].register_context(run_id=task_ins.run_id) + # Retrieve context context = node_states[node_id].retrieve_context(run_id=task_ins.run_id) # Convert TaskIns to Message @@ -151,7 +172,7 @@ def put_taskres_into_state( pass -def run( +def run_api( app_fn: Callable[[], ClientApp], backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, @@ -237,6 +258,7 @@ def start_vce( backend_config_json_stream: str, app_dir: str, f_stop: threading.Event, + run: Run, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, num_supernodes: Optional[int] = None, @@ -287,17 +309,7 @@ def start_vce( ) # Construct mapping of NodeStates - node_states: Dict[int, NodeState] = {} - # Number of unique partitions - num_partitions = len(set(nodes_mapping.values())) - for node_id, partition_id in nodes_mapping.items(): - node_states[node_id] = NodeState( - node_id=node_id, - node_config={ - PARTITION_ID_KEY: str(partition_id), - NUM_PARTITIONS_KEY: str(num_partitions), - }, - ) + node_states = _register_node_states(nodes_mapping=nodes_mapping, run=run) # Load backend config log(DEBUG, "Supported backends: %s", list(supported_backends.keys())) @@ -348,7 +360,7 @@ def _load() -> ClientApp: _ = app_fn() # Run main simulation loop - run( + run_api( app_fn, backend_fn, nodes_mapping, diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 7d37f03f6ade..4dfc08560523 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -165,6 +165,8 @@ def start_and_shutdown( if not app_dir: app_dir = _autoresolve_app_dir() + run = Run(run_id=1234, fab_id="", fab_version="", override_config={}) + start_vce( num_supernodes=num_supernodes, client_app_attr=client_app_attr, @@ -173,6 +175,7 @@ def start_and_shutdown( state_factory=state_factory, app_dir=app_dir, f_stop=f_stop, + run=run, existing_nodes_mapping=nodes_mapping, ) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 7060a972dd9a..60e6e16eed27 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -252,6 +252,7 @@ def _main_loop( app_dir=app_dir, state_factory=state_factory, f_stop=f_stop, + run=run, ) except Exception as ex: From 2f6cec2dcc51303ae44f7e65ae36cea26b84edad Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 15 Jul 2024 18:39:36 +0200 Subject: [PATCH 06/16] feat(framework) Use federations config in `flwr run` (#3800) Co-authored-by: Daniel J. Beutel --- .../app/pyproject.flowertune.toml.tpl | 14 +-- .../new/templates/app/pyproject.hf.toml.tpl | 11 +- .../new/templates/app/pyproject.jax.toml.tpl | 9 +- .../new/templates/app/pyproject.mlx.toml.tpl | 11 +- .../templates/app/pyproject.numpy.toml.tpl | 11 +- .../templates/app/pyproject.pytorch.toml.tpl | 11 +- .../templates/app/pyproject.sklearn.toml.tpl | 11 +- .../app/pyproject.tensorflow.toml.tpl | 11 +- src/py/flwr/cli/run/run.py | 111 +++++++++--------- 9 files changed, 87 insertions(+), 113 deletions(-) diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 2ed6bd36fd89..109cbf66a35b 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -32,11 +29,8 @@ publisher = "$username" serverapp = "$import_name.app:server" clientapp = "$import_name.app:client" -[flower.engine] -name = "simulation" - -[flower.engine.simulation.supernode] -num = $num_clients +[flower.federations] +default = "localhost" -[flower.engine.simulation] -backend_config = { client_resources = { num_cpus = 8, num_gpus = 1.0 } } +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index 71004f3421cd..6c7e50393098 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -30,8 +27,8 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.engine] -name = "simulation" +[flower.federations] +default = "localhost" -[flower.engine.simulation.supernode] -num = 2 +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index c5463e08b92c..f5c66cc729b8 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = {text = "Apache License (2.0)"} dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -26,3 +23,9 @@ publisher = "$username" [flower.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" + +[flower.federations] +default = "localhost" + +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index a850135a1fc5..eaeec144adb2 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -27,8 +24,8 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.engine] -name = "simulation" +[flower.federations] +default = "localhost" -[flower.engine.simulation.supernode] -num = 2 +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index d49015eb567f..6f386990ba6e 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -25,8 +22,8 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.engine] -name = "simulation" +[flower.federations] +default = "localhost" -[flower.engine.simulation.supernode] -num = 2 +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index b56c0041b96c..4313079fa74a 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -27,8 +24,8 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.engine] -name = "simulation" +[flower.federations] +default = "localhost" -[flower.engine.simulation.supernode] -num = 2 +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 6f914ae659b1..8ab7c10d0107 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -26,8 +23,8 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.engine] -name = "simulation" +[flower.federations] +default = "localhost" -[flower.engine.simulation.supernode] -num = 2 +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index 4ecd16143dcc..a64dfbe6bf77 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -6,9 +6,6 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.9.0,<2.0", @@ -26,8 +23,8 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.engine] -name = "simulation" +[flower.federations] +default = "localhost" -[flower.engine.simulation.supernode] -num = 2 +[flower.federations.localhost] +options.num-supernodes = 10 diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index b23ba3f7d0cf..76d1f47e4fa9 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -15,18 +15,16 @@ """Flower command line interface `run` command.""" import sys -from enum import Enum from logging import DEBUG from pathlib import Path -from typing import Dict, Optional +from typing import Any, Dict, Optional import typer from typing_extensions import Annotated -from flwr.cli import config_utils from flwr.cli.build import build +from flwr.cli.config_utils import load_and_validate from flwr.common.config import parse_config_args -from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 @@ -34,31 +32,12 @@ from flwr.simulation.run_simulation import _run_simulation -class Engine(str, Enum): - """Enum defining the engine to run on.""" - - SIMULATION = "simulation" - - # pylint: disable-next=too-many-locals def run( - engine: Annotated[ - Optional[Engine], - typer.Option( - case_sensitive=False, - help="The engine to run FL with (currently only simulation is supported).", - ), - ] = None, - use_superexec: Annotated[ - bool, - typer.Option( - case_sensitive=False, help="Use this flag to use the new SuperExec API" - ), - ] = False, directory: Annotated[ - Optional[Path], - typer.Option(help="Path of the Flower project to run"), - ] = None, + Path, + typer.Argument(help="Path of the Flower project to run"), + ] = Path("."), config_overrides: Annotated[ Optional[str], typer.Option( @@ -72,7 +51,7 @@ def run( typer.secho("Loading project configuration... ", fg=typer.colors.BLUE) pyproject_path = directory / "pyproject.toml" if directory else None - config, errors, warnings = config_utils.load_and_validate(path=pyproject_path) + config, errors, warnings = load_and_validate(path=pyproject_path) if config is None: typer.secho( @@ -94,48 +73,37 @@ def run( typer.secho("Success", fg=typer.colors.GREEN) - if use_superexec: - _start_superexec_run( - parse_config_args(config_overrides, separator=","), directory - ) - return - - server_app_ref = config["flower"]["components"]["serverapp"] - client_app_ref = config["flower"]["components"]["clientapp"] - - if engine is None: - engine = config["flower"]["engine"]["name"] - - if engine == Engine.SIMULATION: - num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"] - backend_config = config["flower"]["engine"]["simulation"].get( - "backend_config", None - ) - - typer.secho("Starting run... ", fg=typer.colors.BLUE) - _run_simulation( - server_app_attr=server_app_ref, - client_app_attr=client_app_ref, - num_supernodes=num_supernodes, - backend_config=backend_config, - ) - else: + try: + federation_name = config["flower"]["federations"]["default"] + federation = config["flower"]["federations"][federation_name] + except KeyError as err: typer.secho( - f"Engine '{engine}' is not yet supported in `flwr run`", + "❌ The project's `pyproject.toml` needs to declare " + "a default federation (with a SuperExec address or an " + "`options.num-supernodes` value).", fg=typer.colors.RED, bold=True, ) + raise typer.Exit(code=1) from err + if "address" in federation: + _run_with_superexec(federation, directory, config_overrides) + else: + _run_without_superexec(config, federation, federation_name) -def _start_superexec_run( - override_config: Dict[str, str], directory: Optional[Path] + +def _run_with_superexec( + federation: Dict[str, str], + directory: Optional[Path], + config_overrides: Optional[str], ) -> None: + def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" log(DEBUG, channel_connectivity) channel = create_channel( - server_address=SUPEREXEC_DEFAULT_ADDRESS, + server_address=federation["address"], insecure=True, root_certificates=None, max_message_length=GRPC_MAX_MESSAGE_LENGTH, @@ -148,7 +116,34 @@ def on_channel_state_change(channel_connectivity: str) -> None: req = StartRunRequest( fab_file=Path(fab_path).read_bytes(), - override_config=override_config, + override_config=parse_config_args(config_overrides, separator=","), ) res = stub.StartRun(req) typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN) + + +def _run_without_superexec( + config: Dict[str, Any], federation: Dict[str, Any], federation_name: str +) -> None: + server_app_ref = config["flower"]["components"]["serverapp"] + client_app_ref = config["flower"]["components"]["clientapp"] + + try: + num_supernodes = federation["options"]["num-supernodes"] + except KeyError as err: + typer.secho( + "❌ The project's `pyproject.toml` needs to declare the number of" + " SuperNodes in the simulation. To simulate 10 SuperNodes," + " use the following notation:\n\n" + f"[flower.federations.{federation_name}]\n" + "options.num-supernodes = 10\n", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) from err + + _run_simulation( + server_app_attr=server_app_ref, + client_app_attr=client_app_ref, + num_supernodes=num_supernodes, + ) From 125a0c7617e8081193b40584ad094e7ec43ccf2d Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 15 Jul 2024 20:19:16 +0200 Subject: [PATCH 07/16] refactor(framework) Refactor `ClientApp` loading to use explicit arguments (#3805) --- src/py/flwr/client/supernode/app.py | 37 +++++++++++++++++++---------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index d61b986bc7af..2f2fa58b428c 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -60,7 +60,12 @@ def run_supernode() -> None: _warn_deprecated_server_arg(args) root_certificates = _get_certificates(args) - load_fn = _get_load_client_app_fn(args, multi_app=True) + load_fn = _get_load_client_app_fn( + default_app_ref=getattr(args, "client-app"), + dir_arg=args.dir, + flwr_dir_arg=args.flwr_dir, + multi_app=True, + ) authentication_keys = _try_setup_client_authentication(args) _start_client_internal( @@ -93,7 +98,11 @@ def run_client_app() -> None: _warn_deprecated_server_arg(args) root_certificates = _get_certificates(args) - load_fn = _get_load_client_app_fn(args, multi_app=False) + load_fn = _get_load_client_app_fn( + default_app_ref=getattr(args, "client-app"), + dir_arg=args.dir, + multi_app=False, + ) authentication_keys = _try_setup_client_authentication(args) _start_client_internal( @@ -166,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]: def _get_load_client_app_fn( - args: argparse.Namespace, multi_app: bool + default_app_ref: str, + dir_arg: str, + multi_app: bool, + flwr_dir_arg: Optional[str] = None, ) -> Callable[[str, str], ClientApp]: """Get the load_client_app_fn function. @@ -178,25 +190,24 @@ def _get_load_client_app_fn( loads a default ClientApp. """ # Find the Flower directory containing Flower Apps (only for multi-app) - flwr_dir = Path("") - if "flwr_dir" in args: - if args.flwr_dir is None: + if not multi_app: + flwr_dir = Path("") + else: + if flwr_dir_arg is None: flwr_dir = get_flwr_dir() else: - flwr_dir = Path(args.flwr_dir).absolute() + flwr_dir = Path(flwr_dir_arg).absolute() inserted_path = None - default_app_ref: str = getattr(args, "client-app") - if not multi_app: log( DEBUG, "Flower SuperNode will load and validate ClientApp `%s`", - getattr(args, "client-app"), + default_app_ref, ) # Insert sys.path - dir_path = Path(args.dir).absolute() + dir_path = Path(dir_arg).absolute() sys.path.insert(0, str(dir_path)) inserted_path = str(dir_path) @@ -208,7 +219,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: # If multi-app feature is disabled if not multi_app: # Get sys path to be inserted - dir_path = Path(args.dir).absolute() + dir_path = Path(dir_arg).absolute() # Set app reference client_app_ref = default_app_ref @@ -221,7 +232,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.") # Get sys path to be inserted - dir_path = Path(args.dir).absolute() + dir_path = Path(dir_arg).absolute() # Set app reference client_app_ref = default_app_ref From 285acfa210eec1532465a64e84f170209e2991ae Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 15 Jul 2024 21:22:17 +0200 Subject: [PATCH 08/16] feat(framework) Add federation argument to `flwr run` (#3807) --- src/py/flwr/cli/run/run.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 76d1f47e4fa9..1ae4017492b0 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -38,6 +38,10 @@ def run( Path, typer.Argument(help="Path of the Flower project to run"), ] = Path("."), + federation_name: Annotated[ + Optional[str], + typer.Argument(help="Name of the federation to run the app on"), + ] = None, config_overrides: Annotated[ Optional[str], typer.Option( @@ -73,18 +77,30 @@ def run( typer.secho("Success", fg=typer.colors.GREEN) - try: - federation_name = config["flower"]["federations"]["default"] - federation = config["flower"]["federations"][federation_name] - except KeyError as err: + federation_name = federation_name or config["flower"]["federations"].get("default") + + if federation_name is None: typer.secho( - "❌ The project's `pyproject.toml` needs to declare " - "a default federation (with a SuperExec address or an " + "❌ No federation name was provided and the project's `pyproject.toml` " + "doesn't declare a default federation (with a SuperExec address or an " "`options.num-supernodes` value).", fg=typer.colors.RED, bold=True, ) - raise typer.Exit(code=1) from err + raise typer.Exit(code=1) + + # Validate the federation exists in the configuration + federation = config["flower"]["federations"].get(federation_name) + if federation is None: + available_feds = list(config["flower"]["federations"]) + typer.secho( + f"❌ There is no `{federation_name}` federation declared in the " + "`pyproject.toml`.\n The following federations were found:\n\n" + "\n".join(available_feds) + "\n\n", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) if "address" in federation: _run_with_superexec(federation, directory, config_overrides) From 22bbc006cff57dc0963ad79b4d694aabe3b86e69 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 15 Jul 2024 21:36:18 +0200 Subject: [PATCH 09/16] refactor(framework) Improve app loading in simulation engine (#3806) --- examples/simulation-pytorch/sim.py | 37 +++++++++++-------- .../server/superlink/fleet/vce/vce_api.py | 31 +++++++++------- src/py/flwr/simulation/run_simulation.py | 18 +++++++++ 3 files changed, 58 insertions(+), 28 deletions(-) diff --git a/examples/simulation-pytorch/sim.py b/examples/simulation-pytorch/sim.py index db68e75653fc..dcc0f39a79ef 100644 --- a/examples/simulation-pytorch/sim.py +++ b/examples/simulation-pytorch/sim.py @@ -87,11 +87,13 @@ def get_client_fn(dataset: FederatedDataset): the strategy to participate. """ - def client_fn(cid: str) -> fl.client.Client: + def client_fn(context) -> fl.client.Client: """Construct a FlowerClient with its own dataset partition.""" # Let's get the partition corresponding to the i-th client - client_dataset = dataset.load_partition(int(cid), "train") + client_dataset = dataset.load_partition( + int(context.node_config["partition-id"]), "train" + ) # Now let's split it into train (90%) and validation (10%) client_dataset_splits = client_dataset.train_test_split(test_size=0.1, seed=42) @@ -171,15 +173,23 @@ def evaluate( mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) centralized_testset = mnist_fds.load_split("test") -# Configure the strategy -strategy = fl.server.strategy.FedAvg( - fraction_fit=0.1, # Sample 10% of available clients for training - fraction_evaluate=0.05, # Sample 5% of available clients for evaluation - min_available_clients=10, - on_fit_config_fn=fit_config, - evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics - evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function -) +from flwr.server import ServerAppComponents + + +def server_fn(context): + # Configure the strategy + strategy = fl.server.strategy.FedAvg( + fraction_fit=0.1, # Sample 10% of available clients for training + fraction_evaluate=0.05, # Sample 5% of available clients for evaluation + min_available_clients=10, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics + evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function + ) + return ServerAppComponents( + strategy=strategy, config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS) + ) + # ClientApp for Flower-Next client = fl.client.ClientApp( @@ -187,10 +197,7 @@ def evaluate( ) # ServerApp for Flower-Next -server = fl.server.ServerApp( - config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), - strategy=strategy, -) +server = fl.server.ServerApp(server_fn=server_fn) def main(): diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 66bca5a391c7..b652207961a1 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -16,7 +16,6 @@ import json -import sys import threading import time import traceback @@ -29,6 +28,7 @@ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState +from flwr.client.supernode.app import _get_load_client_app_fn from flwr.common.constant import ( NUM_PARTITIONS_KEY, PARTITION_ID_KEY, @@ -37,7 +37,6 @@ ) from flwr.common.logger import log from flwr.common.message import Error -from flwr.common.object_ref import load_app from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.common.typing import Run from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -259,6 +258,7 @@ def start_vce( app_dir: str, f_stop: threading.Event, run: Run, + flwr_dir: Optional[str] = None, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, num_supernodes: Optional[int] = None, @@ -338,16 +338,12 @@ def backend_fn() -> Backend: def _load() -> ClientApp: if client_app_attr: - - if app_dir is not None: - sys.path.insert(0, app_dir) - - app: ClientApp = load_app(client_app_attr, LoadClientAppError, app_dir) - - if not isinstance(app, ClientApp): - raise LoadClientAppError( - f"Attribute {client_app_attr} is not of type {ClientApp}", - ) from None + app = _get_load_client_app_fn( + default_app_ref=client_app_attr, + dir_arg=app_dir, + flwr_dir_arg=flwr_dir, + multi_app=True, + )(run.fab_id, run.fab_version) if client_app: app = client_app @@ -357,7 +353,16 @@ def _load() -> ClientApp: try: # Test if ClientApp can be loaded - _ = app_fn() + client_app = app_fn() + + # Cache `ClientApp` + if client_app_attr: + # Now wrap the loaded ClientApp in a dummy function + # this prevent unnecesary low-level loading of ClientApp + def _load_client_app() -> ClientApp: + return client_app + + app_fn = _load_client_app # Run main simulation loop run_api( diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 60e6e16eed27..8c70bf8374d0 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -207,6 +207,7 @@ def _main_loop( app_dir: str, enable_tf_gpu_growth: bool, run: Run, + flwr_dir: Optional[str] = None, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, server_app: Optional[ServerApp] = None, @@ -253,6 +254,7 @@ def _main_loop( state_factory=state_factory, f_stop=f_stop, run=run, + flwr_dir=flwr_dir, ) except Exception as ex: @@ -283,6 +285,7 @@ def _run_simulation( client_app_attr: Optional[str] = None, server_app_attr: Optional[str] = None, app_dir: str = "", + flwr_dir: Optional[str] = None, run: Optional[Run] = None, enable_tf_gpu_growth: bool = False, verbose_logging: bool = False, @@ -326,6 +329,9 @@ def _run_simulation( Add specified directory to the PYTHONPATH and load `ClientApp` from there. (Default: current working directory.) + flwr_dir : Optional[str] + The path containing installed Flower Apps. + run : Optional[Run] An object carrying details about the run. @@ -377,6 +383,7 @@ def _run_simulation( app_dir, enable_tf_gpu_growth, run, + flwr_dir, client_app, client_app_attr, server_app, @@ -464,6 +471,17 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser: "ClientApp and ServerApp from there." " Default: current working directory.", ) + parser.add_argument( + "--flwr-dir", + default=None, + help="""The path containing installed Flower Apps. + By default, this value is equal to: + + - `$FLWR_HOME/` if `$FLWR_HOME` is defined + - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined + - `$HOME/.flwr/` in all other cases + """, + ) parser.add_argument( "--run-id", type=int, From 5ec1697659a9f6396e40d1fcdb9017b55a259609 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 15 Jul 2024 22:26:37 +0200 Subject: [PATCH 10/16] refactor(framework) Switch to `tool.flwr` instead of `flower` in `pyproject.toml` (#3809) --- src/py/flwr/cli/build.py | 2 +- src/py/flwr/cli/config_utils.py | 30 +++--- src/py/flwr/cli/config_utils_test.py | 96 ++++++++----------- src/py/flwr/cli/install.py | 2 +- .../app/pyproject.flowertune.toml.tpl | 8 +- .../new/templates/app/pyproject.hf.toml.tpl | 8 +- .../new/templates/app/pyproject.jax.toml.tpl | 8 +- .../new/templates/app/pyproject.mlx.toml.tpl | 8 +- .../templates/app/pyproject.numpy.toml.tpl | 8 +- .../templates/app/pyproject.pytorch.toml.tpl | 8 +- .../templates/app/pyproject.sklearn.toml.tpl | 8 +- .../app/pyproject.tensorflow.toml.tpl | 8 +- src/py/flwr/cli/run/run.py | 14 +-- src/py/flwr/client/supernode/app.py | 2 +- src/py/flwr/common/config.py | 2 +- src/py/flwr/common/config_test.py | 38 ++++---- src/py/flwr/server/run_serverapp.py | 2 +- src/py/flwr/superexec/simulation.py | 2 +- 18 files changed, 120 insertions(+), 134 deletions(-) diff --git a/src/py/flwr/cli/build.py b/src/py/flwr/cli/build.py index f63d0acd5d73..599ce613698c 100644 --- a/src/py/flwr/cli/build.py +++ b/src/py/flwr/cli/build.py @@ -85,7 +85,7 @@ def build( # Set the name of the zip file fab_filename = ( - f"{conf['flower']['publisher']}" + f"{conf['tool']['flwr']['publisher']}" f".{directory.name}" f".{conf['project']['version'].replace('.', '-')}.fab" ) diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index 33bf12e34b04..9147ebba4995 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -60,7 +60,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]: return ( conf["project"]["version"], - f"{conf['flower']['publisher']}/{conf['project']['name']}", + f"{conf['tool']['flwr']['publisher']}/{conf['project']['name']}", ) @@ -136,20 +136,20 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] if "authors" not in config["project"]: warnings.append('Recommended property "authors" missing in [project]') - if "flower" not in config: - errors.append("Missing [flower] section") + if "tool" not in config or "flwr" not in config["tool"]: + errors.append("Missing [tool.flwr] section") else: - if "publisher" not in config["flower"]: - errors.append('Property "publisher" missing in [flower]') - if "config" in config["flower"]: - _validate_run_config(config["flower"]["config"], errors) - if "components" not in config["flower"]: - errors.append("Missing [flower.components] section") + if "publisher" not in config["tool"]["flwr"]: + errors.append('Property "publisher" missing in [tool.flwr]') + if "config" in config["tool"]["flwr"]: + _validate_run_config(config["tool"]["flwr"]["config"], errors) + if "components" not in config["tool"]["flwr"]: + errors.append("Missing [tool.flwr.components] section") else: - if "serverapp" not in config["flower"]["components"]: - errors.append('Property "serverapp" missing in [flower.components]') - if "clientapp" not in config["flower"]["components"]: - errors.append('Property "clientapp" missing in [flower.components]') + if "serverapp" not in config["tool"]["flwr"]["components"]: + errors.append('Property "serverapp" missing in [tool.flwr.components]') + if "clientapp" not in config["tool"]["flwr"]["components"]: + errors.append('Property "clientapp" missing in [tool.flwr.components]') return len(errors) == 0, errors, warnings @@ -165,14 +165,14 @@ def validate( # Validate serverapp is_valid, reason = object_ref.validate( - config["flower"]["components"]["serverapp"], check_module + config["tool"]["flwr"]["components"]["serverapp"], check_module ) if not is_valid and isinstance(reason, str): return False, [reason], [] # Validate clientapp is_valid, reason = object_ref.validate( - config["flower"]["components"]["clientapp"], check_module + config["tool"]["flwr"]["components"]["clientapp"], check_module ) if not is_valid and isinstance(reason, str): diff --git a/src/py/flwr/cli/config_utils_test.py b/src/py/flwr/cli/config_utils_test.py index b24425cd08f4..35d9900703b6 100644 --- a/src/py/flwr/cli/config_utils_test.py +++ b/src/py/flwr/cli/config_utils_test.py @@ -34,27 +34,18 @@ def test_load_pyproject_toml_load_from_cwd(tmp_path: Path) -> None: name = "fedgpt" version = "1.0.0" description = "" - authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, - ] license = {text = "Apache License (2.0)"} dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0", ] - [flower] + [tool.flwr] publisher = "flwrlabs" - [flower.components] + [tool.flwr.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" - - [flower.engine] - name = "simulation" # optional - - [flower.engine.simulation.supernode] - count = 10 # optional """ expected_config = { "build-system": {"build-backend": "hatchling.build", "requires": ["hatchling"]}, @@ -62,19 +53,16 @@ def test_load_pyproject_toml_load_from_cwd(tmp_path: Path) -> None: "name": "fedgpt", "version": "1.0.0", "description": "", - "authors": [{"email": "hello@flower.ai", "name": "The Flower Authors"}], "license": {"text": "Apache License (2.0)"}, "dependencies": ["flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0"], }, - "flower": { - "publisher": "flwrlabs", - "components": { - "serverapp": "fedgpt.server:app", - "clientapp": "fedgpt.client:app", - }, - "engine": { - "name": "simulation", - "simulation": {"supernode": {"count": 10}}, + "tool": { + "flwr": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, }, }, } @@ -109,27 +97,18 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None: name = "fedgpt" version = "1.0.0" description = "" - authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, - ] license = {text = "Apache License (2.0)"} dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0", ] - [flower] + [tool.flwr] publisher = "flwrlabs" - [flower.components] + [tool.flwr.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" - - [flower.engine] - name = "simulation" # optional - - [flower.engine.simulation.supernode] - count = 10 # optional """ expected_config = { "build-system": {"build-backend": "hatchling.build", "requires": ["hatchling"]}, @@ -137,19 +116,16 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None: "name": "fedgpt", "version": "1.0.0", "description": "", - "authors": [{"email": "hello@flower.ai", "name": "The Flower Authors"}], "license": {"text": "Apache License (2.0)"}, "dependencies": ["flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0"], }, - "flower": { - "publisher": "flwrlabs", - "components": { - "serverapp": "fedgpt.server:app", - "clientapp": "fedgpt.client:app", - }, - "engine": { - "name": "simulation", - "simulation": {"supernode": {"count": 10}}, + "tool": { + "flwr": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, }, }, } @@ -219,7 +195,7 @@ def test_validate_pyproject_toml_fields_no_flower_components() -> None: "license": "", "authors": [], }, - "flower": {}, + "tool": {"flwr": {}}, } # Execute @@ -242,7 +218,7 @@ def test_validate_pyproject_toml_fields_no_server_and_client_app() -> None: "license": "", "authors": [], }, - "flower": {"components": {}}, + "tool": {"flwr": {"components": {}}}, } # Execute @@ -265,9 +241,11 @@ def test_validate_pyproject_toml_fields() -> None: "license": "", "authors": [], }, - "flower": { - "publisher": "flwrlabs", - "components": {"serverapp": "", "clientapp": ""}, + "tool": { + "flwr": { + "publisher": "flwrlabs", + "components": {"serverapp": "", "clientapp": ""}, + }, }, } @@ -291,11 +269,13 @@ def test_validate_pyproject_toml() -> None: "license": "", "authors": [], }, - "flower": { - "publisher": "flwrlabs", - "components": { - "serverapp": "flwr.cli.run:run", - "clientapp": "flwr.cli.run:run", + "tool": { + "flwr": { + "publisher": "flwrlabs", + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:run", + }, }, }, } @@ -320,11 +300,13 @@ def test_validate_pyproject_toml_fail() -> None: "license": "", "authors": [], }, - "flower": { - "publisher": "flwrlabs", - "components": { - "serverapp": "flwr.cli.run:run", - "clientapp": "flwr.cli.run:runa", + "tool": { + "flwr": { + "publisher": "flwrlabs", + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:runa", + }, }, }, } diff --git a/src/py/flwr/cli/install.py b/src/py/flwr/cli/install.py index de9227bee450..7444f10c1eb7 100644 --- a/src/py/flwr/cli/install.py +++ b/src/py/flwr/cli/install.py @@ -149,7 +149,7 @@ def validate_and_install( ) raise typer.Exit(code=1) - publisher = config["flower"]["publisher"] + publisher = config["tool"]["flwr"]["publisher"] project_name = config["project"]["name"] version = config["project"]["version"] diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 109cbf66a35b..17630dd9d0dc 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -22,15 +22,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.app:server" clientapp = "$import_name.app:client" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index 6c7e50393098..6f46d6de5bf5 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -20,15 +20,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index f5c66cc729b8..045a1f4e57eb 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -17,15 +17,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index eaeec144adb2..5ea2c420d6f8 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -17,15 +17,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index 6f386990ba6e..d166616bb616 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -15,15 +15,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index 4313079fa74a..c0323126516d 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -17,15 +17,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 8ab7c10d0107..0e63375aab00 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -16,15 +16,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index a64dfbe6bf77..aeca4a17805f 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -16,15 +16,15 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[flower] +[tool.flwr] publisher = "$username" -[flower.components] +[tool.flwr.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" -[flower.federations] +[tool.flwr.federations] default = "localhost" -[flower.federations.localhost] +[tool.flwr.federations.localhost] options.num-supernodes = 10 diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 1ae4017492b0..c39ae0decd4b 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -77,7 +77,9 @@ def run( typer.secho("Success", fg=typer.colors.GREEN) - federation_name = federation_name or config["flower"]["federations"].get("default") + federation_name = federation_name or config["tool"]["flwr"]["federations"].get( + "default" + ) if federation_name is None: typer.secho( @@ -90,9 +92,9 @@ def run( raise typer.Exit(code=1) # Validate the federation exists in the configuration - federation = config["flower"]["federations"].get(federation_name) + federation = config["tool"]["flwr"]["federations"].get(federation_name) if federation is None: - available_feds = list(config["flower"]["federations"]) + available_feds = list(config["tool"]["flwr"]["federations"]) typer.secho( f"❌ There is no `{federation_name}` federation declared in the " "`pyproject.toml`.\n The following federations were found:\n\n" @@ -141,8 +143,8 @@ def on_channel_state_change(channel_connectivity: str) -> None: def _run_without_superexec( config: Dict[str, Any], federation: Dict[str, Any], federation_name: str ) -> None: - server_app_ref = config["flower"]["components"]["serverapp"] - client_app_ref = config["flower"]["components"]["clientapp"] + server_app_ref = config["tool"]["flwr"]["components"]["serverapp"] + client_app_ref = config["tool"]["flwr"]["components"]["clientapp"] try: num_supernodes = federation["options"]["num-supernodes"] @@ -151,7 +153,7 @@ def _run_without_superexec( "❌ The project's `pyproject.toml` needs to declare the number of" " SuperNodes in the simulation. To simulate 10 SuperNodes," " use the following notation:\n\n" - f"[flower.federations.{federation_name}]\n" + f"[tool.flwr.federations.{federation_name}]\n" "options.num-supernodes = 10\n", fg=typer.colors.RED, bold=True, diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 2f2fa58b428c..027c3376b7f3 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -248,7 +248,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: dir_path = Path(project_dir).absolute() # Set app reference - client_app_ref = config["flower"]["components"]["clientapp"] + client_app_ref = config["tool"]["flwr"]["components"]["clientapp"] # Set sys.path nonlocal inserted_path diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 54d74353e4ed..e2b06ff86110 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -97,7 +97,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir) - default_config = get_project_config(project_dir)["flower"].get("config", {}) + default_config = get_project_config(project_dir)["tool"]["flwr"].get("config", {}) flat_default_config = flatten_dict(default_config) return _fuse_dicts(flat_default_config, run.override_config) diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index fe429bab9cb5..899240c1e76a 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -93,20 +93,20 @@ def test_get_fused_config_valid(tmp_path: Path) -> None: "numpy>=1.21.0", ] - [flower] + [tool.flwr] publisher = "flwrlabs" - [flower.components] + [tool.flwr.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" - [flower.config] + [tool.flwr.config] num_server_rounds = "10" momentum = "0.1" lr = "0.01" serverapp.test = "key" - [flower.config.clientapp] + [tool.flwr.config.clientapp] test = "key" """ overrides = { @@ -131,7 +131,7 @@ def test_get_fused_config_valid(tmp_path: Path) -> None: f.write(textwrap.dedent(pyproject_toml_content)) # Execute - default_config = get_project_config(tmp_path)["flower"].get("config", {}) + default_config = get_project_config(tmp_path)["tool"]["flwr"].get("config", {}) config = _fuse_dicts(flatten_dict(default_config), overrides) @@ -158,14 +158,14 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: "numpy>=1.21.0", ] - [flower] + [tool.flwr] publisher = "flwrlabs" - [flower.components] + [tool.flwr.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" - [flower.config] + [tool.flwr.config] num_server_rounds = "10" momentum = "0.1" lr = "0.01" @@ -179,16 +179,18 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: "license": {"text": "Apache License (2.0)"}, "dependencies": ["flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0"], }, - "flower": { - "publisher": "flwrlabs", - "components": { - "serverapp": "fedgpt.server:app", - "clientapp": "fedgpt.client:app", - }, - "config": { - "num_server_rounds": "10", - "momentum": "0.1", - "lr": "0.01", + "tool": { + "flwr": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "config": { + "num_server_rounds": "10", + "momentum": "0.1", + "lr": "0.01", + }, }, }, } diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 4cc25feb7e0e..efaba24f05f9 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -186,7 +186,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches run_ = driver.run server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir)) config = get_project_config(server_app_dir) - server_app_attr = config["flower"]["components"]["serverapp"] + server_app_attr = config["tool"]["flwr"]["components"]["serverapp"] server_app_run_config = get_fused_config(run_, flwr_dir) else: # User provided `server-app`, but not `--run-id` diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index 9a8e19365ab9..fa7a8ad9b0d3 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -112,7 +112,7 @@ def start_run( ) # Get ClientApp and SeverApp components - flower_components = config["flower"]["components"] + flower_components = config["tool"]["flwr"]["components"] clientapp = flower_components["clientapp"] serverapp = flower_components["serverapp"] From 690e6acd84e42609e3be10a60b06444b252f8f74 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 15 Jul 2024 22:39:08 +0200 Subject: [PATCH 11/16] feat(framework) Add secure channel support for SuperExec (#3808) Co-authored-by: Daniel J. Beutel --- src/py/flwr/cli/run/run.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index c39ae0decd4b..512da83d13fe 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -120,10 +120,38 @@ def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" log(DEBUG, channel_connectivity) + insecure_str = federation.get("insecure") + if root_certificates := federation.get("root-certificates"): + root_certificates_bytes = Path(root_certificates).read_bytes() + if insecure := bool(insecure_str): + typer.secho( + "❌ `root_certificates` were provided but the `insecure` parameter" + "is set to `True`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + else: + root_certificates_bytes = None + if insecure_str is None: + typer.secho( + "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + if not (insecure := bool(insecure_str)): + typer.secho( + "❌ No certificate were given yet `insecure` is set to `False`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + channel = create_channel( server_address=federation["address"], - insecure=True, - root_certificates=None, + insecure=insecure, + root_certificates=root_certificates_bytes, max_message_length=GRPC_MAX_MESSAGE_LENGTH, interceptors=None, ) From b9c3a363bf2d801677f663909b326b723d275124 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Tue, 16 Jul 2024 10:57:48 +0200 Subject: [PATCH 12/16] refactor(framework) Move `tool.flwr` to `tool.flwr.app` (#3811) --- src/py/flwr/cli/build.py | 2 +- src/py/flwr/cli/config_utils.py | 38 +++++++----- src/py/flwr/cli/config_utils_test.py | 58 +++++++++++-------- src/py/flwr/cli/install.py | 2 +- .../app/pyproject.flowertune.toml.tpl | 4 +- .../new/templates/app/pyproject.hf.toml.tpl | 4 +- .../new/templates/app/pyproject.jax.toml.tpl | 4 +- .../new/templates/app/pyproject.mlx.toml.tpl | 4 +- .../templates/app/pyproject.numpy.toml.tpl | 4 +- .../templates/app/pyproject.pytorch.toml.tpl | 4 +- .../templates/app/pyproject.sklearn.toml.tpl | 4 +- .../app/pyproject.tensorflow.toml.tpl | 4 +- src/py/flwr/common/config.py | 4 +- src/py/flwr/common/config_test.py | 38 ++++++------ 14 files changed, 99 insertions(+), 75 deletions(-) diff --git a/src/py/flwr/cli/build.py b/src/py/flwr/cli/build.py index 599ce613698c..670b8bd64908 100644 --- a/src/py/flwr/cli/build.py +++ b/src/py/flwr/cli/build.py @@ -85,7 +85,7 @@ def build( # Set the name of the zip file fab_filename = ( - f"{conf['tool']['flwr']['publisher']}" + f"{conf['tool']['flwr']['app']['publisher']}" f".{directory.name}" f".{conf['project']['version'].replace('.', '-')}.fab" ) diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index 9147ebba4995..f46a53857dfc 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -60,7 +60,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]: return ( conf["project"]["version"], - f"{conf['tool']['flwr']['publisher']}/{conf['project']['name']}", + f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}", ) @@ -136,20 +136,28 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] if "authors" not in config["project"]: warnings.append('Recommended property "authors" missing in [project]') - if "tool" not in config or "flwr" not in config["tool"]: - errors.append("Missing [tool.flwr] section") + if ( + "tool" not in config + or "flwr" not in config["tool"] + or "app" not in config["tool"]["flwr"] + ): + errors.append("Missing [tool.flwr.app] section") else: - if "publisher" not in config["tool"]["flwr"]: - errors.append('Property "publisher" missing in [tool.flwr]') - if "config" in config["tool"]["flwr"]: - _validate_run_config(config["tool"]["flwr"]["config"], errors) - if "components" not in config["tool"]["flwr"]: - errors.append("Missing [tool.flwr.components] section") + if "publisher" not in config["tool"]["flwr"]["app"]: + errors.append('Property "publisher" missing in [tool.flwr.app]') + if "config" in config["tool"]["flwr"]["app"]: + _validate_run_config(config["tool"]["flwr"]["app"]["config"], errors) + if "components" not in config["tool"]["flwr"]["app"]: + errors.append("Missing [tool.flwr.app.components] section") else: - if "serverapp" not in config["tool"]["flwr"]["components"]: - errors.append('Property "serverapp" missing in [tool.flwr.components]') - if "clientapp" not in config["tool"]["flwr"]["components"]: - errors.append('Property "clientapp" missing in [tool.flwr.components]') + if "serverapp" not in config["tool"]["flwr"]["app"]["components"]: + errors.append( + 'Property "serverapp" missing in [tool.flwr.app.components]' + ) + if "clientapp" not in config["tool"]["flwr"]["app"]["components"]: + errors.append( + 'Property "clientapp" missing in [tool.flwr.app.components]' + ) return len(errors) == 0, errors, warnings @@ -165,14 +173,14 @@ def validate( # Validate serverapp is_valid, reason = object_ref.validate( - config["tool"]["flwr"]["components"]["serverapp"], check_module + config["tool"]["flwr"]["app"]["components"]["serverapp"], check_module ) if not is_valid and isinstance(reason, str): return False, [reason], [] # Validate clientapp is_valid, reason = object_ref.validate( - config["tool"]["flwr"]["components"]["clientapp"], check_module + config["tool"]["flwr"]["app"]["components"]["clientapp"], check_module ) if not is_valid and isinstance(reason, str): diff --git a/src/py/flwr/cli/config_utils_test.py b/src/py/flwr/cli/config_utils_test.py index 35d9900703b6..077f254fb914 100644 --- a/src/py/flwr/cli/config_utils_test.py +++ b/src/py/flwr/cli/config_utils_test.py @@ -40,10 +40,10 @@ def test_load_pyproject_toml_load_from_cwd(tmp_path: Path) -> None: "numpy>=1.21.0", ] - [tool.flwr] + [tool.flwr.app] publisher = "flwrlabs" - [tool.flwr.components] + [tool.flwr.app.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" """ @@ -58,10 +58,12 @@ def test_load_pyproject_toml_load_from_cwd(tmp_path: Path) -> None: }, "tool": { "flwr": { - "publisher": "flwrlabs", - "components": { - "serverapp": "fedgpt.server:app", - "clientapp": "fedgpt.client:app", + "app": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, }, }, }, @@ -103,10 +105,10 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None: "numpy>=1.21.0", ] - [tool.flwr] + [tool.flwr.app] publisher = "flwrlabs" - [tool.flwr.components] + [tool.flwr.app.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" """ @@ -121,10 +123,12 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None: }, "tool": { "flwr": { - "publisher": "flwrlabs", - "components": { - "serverapp": "fedgpt.server:app", - "clientapp": "fedgpt.client:app", + "app": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, }, }, }, @@ -195,7 +199,7 @@ def test_validate_pyproject_toml_fields_no_flower_components() -> None: "license": "", "authors": [], }, - "tool": {"flwr": {}}, + "tool": {"flwr": {"app": {}}}, } # Execute @@ -218,7 +222,7 @@ def test_validate_pyproject_toml_fields_no_server_and_client_app() -> None: "license": "", "authors": [], }, - "tool": {"flwr": {"components": {}}}, + "tool": {"flwr": {"app": {"components": {}}}}, } # Execute @@ -243,8 +247,10 @@ def test_validate_pyproject_toml_fields() -> None: }, "tool": { "flwr": { - "publisher": "flwrlabs", - "components": {"serverapp": "", "clientapp": ""}, + "app": { + "publisher": "flwrlabs", + "components": {"serverapp": "", "clientapp": ""}, + }, }, }, } @@ -271,10 +277,12 @@ def test_validate_pyproject_toml() -> None: }, "tool": { "flwr": { - "publisher": "flwrlabs", - "components": { - "serverapp": "flwr.cli.run:run", - "clientapp": "flwr.cli.run:run", + "app": { + "publisher": "flwrlabs", + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:run", + }, }, }, }, @@ -302,10 +310,12 @@ def test_validate_pyproject_toml_fail() -> None: }, "tool": { "flwr": { - "publisher": "flwrlabs", - "components": { - "serverapp": "flwr.cli.run:run", - "clientapp": "flwr.cli.run:runa", + "app": { + "publisher": "flwrlabs", + "components": { + "serverapp": "flwr.cli.run:run", + "clientapp": "flwr.cli.run:runa", + }, }, }, }, diff --git a/src/py/flwr/cli/install.py b/src/py/flwr/cli/install.py index 7444f10c1eb7..a1a66e42fd65 100644 --- a/src/py/flwr/cli/install.py +++ b/src/py/flwr/cli/install.py @@ -149,7 +149,7 @@ def validate_and_install( ) raise typer.Exit(code=1) - publisher = config["tool"]["flwr"]["publisher"] + publisher = config["tool"]["flwr"]["app"]["publisher"] project_name = config["project"]["name"] version = config["project"]["version"] diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 17630dd9d0dc..5934a258564b 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -22,10 +22,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.app:server" clientapp = "$import_name.app:client" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index 6f46d6de5bf5..59f41f062af4 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -20,10 +20,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index 045a1f4e57eb..27f4b30ec3b8 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -17,10 +17,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index 5ea2c420d6f8..9c8905a0e8e5 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -17,10 +17,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index d166616bb616..38bfa1888c4c 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -15,10 +15,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index c0323126516d..2fd366d8e350 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -17,10 +17,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 0e63375aab00..143c3756858c 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -16,10 +16,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index aeca4a17805f..964bef58c498 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -16,10 +16,10 @@ dependencies = [ [tool.hatch.build.targets.wheel] packages = ["."] -[tool.flwr] +[tool.flwr.app] publisher = "$username" -[tool.flwr.components] +[tool.flwr.app.components] serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index e2b06ff86110..247c4ef775a7 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -97,7 +97,9 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir) - default_config = get_project_config(project_dir)["tool"]["flwr"].get("config", {}) + default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get( + "config", {} + ) flat_default_config = flatten_dict(default_config) return _fuse_dicts(flat_default_config, run.override_config) diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index 899240c1e76a..feef89e7d5cb 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -93,20 +93,20 @@ def test_get_fused_config_valid(tmp_path: Path) -> None: "numpy>=1.21.0", ] - [tool.flwr] + [tool.flwr.app] publisher = "flwrlabs" - [tool.flwr.components] + [tool.flwr.app.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" - [tool.flwr.config] + [tool.flwr.app.config] num_server_rounds = "10" momentum = "0.1" lr = "0.01" serverapp.test = "key" - [tool.flwr.config.clientapp] + [tool.flwr.app.config.clientapp] test = "key" """ overrides = { @@ -131,7 +131,9 @@ def test_get_fused_config_valid(tmp_path: Path) -> None: f.write(textwrap.dedent(pyproject_toml_content)) # Execute - default_config = get_project_config(tmp_path)["tool"]["flwr"].get("config", {}) + default_config = get_project_config(tmp_path)["tool"]["flwr"]["app"].get( + "config", {} + ) config = _fuse_dicts(flatten_dict(default_config), overrides) @@ -158,14 +160,14 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: "numpy>=1.21.0", ] - [tool.flwr] + [tool.flwr.app] publisher = "flwrlabs" - [tool.flwr.components] + [tool.flwr.app.components] serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" - [tool.flwr.config] + [tool.flwr.app.config] num_server_rounds = "10" momentum = "0.1" lr = "0.01" @@ -181,15 +183,17 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: }, "tool": { "flwr": { - "publisher": "flwrlabs", - "components": { - "serverapp": "fedgpt.server:app", - "clientapp": "fedgpt.client:app", - }, - "config": { - "num_server_rounds": "10", - "momentum": "0.1", - "lr": "0.01", + "app": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "config": { + "num_server_rounds": "10", + "momentum": "0.1", + "lr": "0.01", + }, }, }, }, From 028f619d44cb4280bc13fd1d45ab2eddbafaa281 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Tue, 16 Jul 2024 11:13:59 +0200 Subject: [PATCH 13/16] feat(framework:skip) Add config function for fusing dicts (#3813) Co-authored-by: Daniel J. Beutel --- src/py/flwr/common/config.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 247c4ef775a7..6049fcbcceed 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -86,6 +86,18 @@ def _fuse_dicts( return fused_dict +def get_fused_config_from_dir( + project_dir: Path, override_config: Dict[str, str] +) -> Dict[str, str]: + """Merge the overrides from a given dict with the config from a Flower App.""" + default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get( + "config", {} + ) + flat_default_config = flatten_dict(default_config) + + return _fuse_dicts(flat_default_config, override_config) + + def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: """Merge the overrides from a `Run` with the config from a FAB. @@ -97,12 +109,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir) - default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get( - "config", {} - ) - flat_default_config = flatten_dict(default_config) - - return _fuse_dicts(flat_default_config, run.override_config) + return get_fused_config_from_dir(project_dir, run.override_config) def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]: From 517016c79c6c83154480ed4c6952e3513210e7fb Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Tue, 16 Jul 2024 13:10:37 +0200 Subject: [PATCH 14/16] feat(framework) Remove federations field from FAB (#3814) --- pyproject.toml | 1 + src/py/flwr/cli/build.py | 16 +++++++++++++++- .../templates/app/pyproject.flowertune.toml.tpl | 2 +- .../cli/new/templates/app/pyproject.hf.toml.tpl | 2 +- .../cli/new/templates/app/pyproject.jax.toml.tpl | 2 +- .../cli/new/templates/app/pyproject.mlx.toml.tpl | 2 +- .../new/templates/app/pyproject.numpy.toml.tpl | 2 +- .../new/templates/app/pyproject.pytorch.toml.tpl | 2 +- .../new/templates/app/pyproject.sklearn.toml.tpl | 2 +- .../templates/app/pyproject.tensorflow.toml.tpl | 2 +- 10 files changed, 24 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c5ab0e5edcee..7fe1ef7843d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ pycryptodome = "^3.18.0" iterators = "^0.0.2" typer = { version = "^0.9.0", extras=["all"] } tomli = "^2.0.1" +tomli-w = "^1.0.0" pathspec = "^0.12.1" # Optional dependencies (Simulation Engine) ray = { version = "==2.10.0", optional = true, python = ">=3.8,<3.12" } diff --git a/src/py/flwr/cli/build.py b/src/py/flwr/cli/build.py index 670b8bd64908..1f7f75d36184 100644 --- a/src/py/flwr/cli/build.py +++ b/src/py/flwr/cli/build.py @@ -20,6 +20,7 @@ from typing import Optional import pathspec +import tomli_w import typer from typing_extensions import Annotated @@ -93,15 +94,28 @@ def build( allowed_extensions = {".py", ".toml", ".md"} + # Remove the 'federations' field from 'tool.flwr' if it exists + if ( + "tool" in conf + and "flwr" in conf["tool"] + and "federations" in conf["tool"]["flwr"] + ): + del conf["tool"]["flwr"]["federations"] + + toml_contents = tomli_w.dumps(conf) + with zipfile.ZipFile(fab_filename, "w", zipfile.ZIP_DEFLATED) as fab_file: + fab_file.writestr("pyproject.toml", toml_contents) + + # Continue with adding other files for root, _, files in os.walk(directory, topdown=True): - # Filter directories and files based on .gitignore files = [ f for f in files if not ignore_spec.match_file(Path(root) / f) and f != fab_filename and Path(f).suffix in allowed_extensions + and f != "pyproject.toml" # Exclude the original pyproject.toml ] for file in files: diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 5934a258564b..507b5d50b843 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = { text = "Apache License (2.0)" } +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "flwr-datasets>=0.1.0,<1.0.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index 59f41f062af4..7a63e1ab5368 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = { text = "Apache License (2.0)" } +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "flwr-datasets>=0.0.2,<1.0.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index 27f4b30ec3b8..297784a4d2d8 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = {text = "Apache License (2.0)"} +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "jax==0.4.26", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index 9c8905a0e8e5..fb55f6628cea 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = { text = "Apache License (2.0)" } +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "flwr-datasets[vision]>=0.0.2,<1.0.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index 38bfa1888c4c..ae88472647dc 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = { text = "Apache License (2.0)" } +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index 2fd366d8e350..2dd49a25fd90 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = { text = "Apache License (2.0)" } +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "flwr-datasets[vision]>=0.0.2,<1.0.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 143c3756858c..8458fa64ea2d 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = { text = "Apache License (2.0)" } +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "flwr-datasets[vision]>=0.0.2,<1.0.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index 964bef58c498..2bf0e7d5642c 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "$package_name" version = "1.0.0" description = "" -license = { text = "Apache License (2.0)" } +license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.9.0,<2.0", "flwr-datasets[vision]>=0.0.2,<1.0.0", From 284921db5dc5442ed5e156c80101be98f51e8ad4 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Tue, 16 Jul 2024 15:43:03 +0200 Subject: [PATCH 15/16] ci(datasets:skip) Fix pyarrow version (#3817) --- datasets/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml index 017374181f59..e3afd8b87075 100644 --- a/datasets/pyproject.toml +++ b/datasets/pyproject.toml @@ -59,6 +59,7 @@ pillow = { version = ">=6.2.1", optional = true } soundfile = { version = ">=0.12.1", optional = true } librosa = { version = ">=0.10.0.post2", optional = true } tqdm ="^4.66.1" +pyarrow = "==16.1.0" matplotlib = "^3.7.5" seaborn = "^0.13.0" From 1f3fe0f7fe6d4191d24084fdaa74eb528d57eec3 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 16 Jul 2024 15:49:05 +0200 Subject: [PATCH 16/16] feat(framework) Update context registration when running an app directory (#3815) --- src/py/flwr/client/app.py | 6 +++--- src/py/flwr/client/node_state.py | 20 +++++++++++++++++--- src/py/flwr/client/supernode/app.py | 2 +- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 348ef8910dd3..127bb423851f 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -195,7 +195,7 @@ def _start_client_internal( ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, - flwr_dir: Optional[Path] = None, + flwr_path: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -241,7 +241,7 @@ class `flwr.client.Client` (default: None) The maximum duration before the client stops trying to connect to the server in case of connection error. If set to None, there is no limit to the total time. - flwr_dir: Optional[Path] (default: None) + flwr_path: Optional[Path] (default: None) The fully resolved path containing installed Flower Apps. """ if insecure is None: @@ -402,7 +402,7 @@ def _on_backoff(retry_state: RetryState) -> None: # Register context for this run node_state.register_context( - run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir + run_id=run_id, run=runs[run_id], flwr_path=flwr_path ) # Retrieve context for this run diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 393ca4564a35..08c19967ea3d 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -20,7 +20,7 @@ from typing import Dict, Optional from flwr.common import Context, RecordSet -from flwr.common.config import get_fused_config +from flwr.common.config import get_fused_config, get_fused_config_from_dir from flwr.common.typing import Run @@ -48,11 +48,25 @@ def register_context( self, run_id: int, run: Optional[Run] = None, - flwr_dir: Optional[Path] = None, + flwr_path: Optional[Path] = None, + app_dir: Optional[str] = None, ) -> None: """Register new run context for this node.""" if run_id not in self.run_infos: - initial_run_config = get_fused_config(run, flwr_dir) if run else {} + initial_run_config = {} + if app_dir: + # Load from app directory + app_path = Path(app_dir) + if app_path.is_dir(): + override_config = run.override_config if run else {} + initial_run_config = get_fused_config_from_dir( + app_path, override_config + ) + else: + raise ValueError("The specified `app_dir` must be a directory.") + else: + # Load from .fab + initial_run_config = get_fused_config(run, flwr_path) if run else {} self.run_infos[run_id] = RunInfo( initial_run_config=initial_run_config, context=Context( diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 027c3376b7f3..a364318c766c 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -78,7 +78,7 @@ def run_supernode() -> None: max_retries=args.max_retries, max_wait_time=args.max_wait_time, node_config=parse_config_args(args.node_config), - flwr_dir=get_flwr_dir(args.flwr_dir), + flwr_path=get_flwr_dir(args.flwr_dir), ) # Graceful shutdown