From 2b0c33fa6429d876966995db0df9f56939bf3292 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 26 Feb 2024 11:18:30 +0100 Subject: [PATCH 01/17] Introduce Backend.terminate and asyncio event to stop (#3008) --- .../superlink/fleet/vce/backend/backend.py | 4 ++++ .../superlink/fleet/vce/backend/raybackend.py | 9 ++++++++- .../flwr/server/superlink/fleet/vce/vce_api.py | 4 +++- .../flwr/simulation/ray_transport/ray_actor.py | 18 +++++++++++++++--- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py index 2df4be76e7a0..f2796a5758a0 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -53,6 +53,10 @@ def num_workers(self) -> int: def is_worker_idle(self) -> bool: """Report whether a backend worker is idle and can therefore run a ClientApp.""" + @abstractmethod + async def terminate(self) -> None: + """Terminate backend.""" + @abstractmethod async def process_message( self, diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 24620aab083f..cc3cf4348491 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -28,6 +28,7 @@ ClientAppActor, init_ray, ) +from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth from .backend import Backend, BackendConfig @@ -56,7 +57,9 @@ def __init__( self.client_resources_key = "client_resources" # Create actor pool - actor_kwargs = backend_config.get("actor_kwargs", {}) + use_tf = backend_config.get("tensorflow", False) + actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {} + client_resources = self._validate_client_resources(config=backend_config) self.pool = BasicActorPool( actor_type=ClientAppActor, @@ -151,3 +154,7 
@@ async def process_message( ) = await self.pool.fetch_result_and_return_actor_to_pool(future) return out_mssg, updated_context + + async def terminate(self) -> None: + """Terminate all actors in actor pool.""" + await self.pool.terminate_all_actors() diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index c91bae9ddabb..666e7e7d9ec3 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -14,9 +14,10 @@ # ============================================================================== """Fleet VirtualClientEngine API.""" +import asyncio import json from logging import ERROR, INFO -from typing import Dict +from typing import Dict, Optional from flwr.client.clientapp import ClientApp, load_client_app from flwr.client.node_state import NodeState @@ -49,6 +50,7 @@ def start_vce( backend_config_json_stream: str, state_factory: StateFactory, working_dir: str, + f_stop: Optional[asyncio.Event] = None, ) -> None: """Start Fleet API with the VirtualClientEngine (VCE).""" # Register SuperNodes diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index e899ce282618..5ac0b2c27484 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -18,7 +18,7 @@ import threading import traceback from abc import ABC -from logging import ERROR, WARNING +from logging import DEBUG, ERROR, WARNING from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import ray @@ -46,7 +46,7 @@ class VirtualClientEngineActor(ABC): def terminate(self) -> None: """Manually terminate Actor object.""" - log(WARNING, "Manually terminating %s}", self.__class__.__name__) + log(WARNING, "Manually terminating %s", self.__class__.__name__) ray.actor.exit_actor() def run( @@ -434,7 +434,9 @@ def __init__( self.client_resources = 
client_resources # Queue of idle actors - self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue() + self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue( + maxsize=1024 + ) self.num_actors = 0 # Resolve arguments to pass during actor init @@ -464,6 +466,16 @@ async def add_actors_to_pool(self, num_actors: int) -> None: await self.pool.put(self.create_actor_fn()) # type: ignore self.num_actors += num_actors + async def terminate_all_actors(self) -> None: + """Terminate actors in pool.""" + num_terminated = 0 + while self.pool.qsize(): + actor = await self.pool.get() + actor.terminate.remote() # type: ignore + num_terminated += 1 + + log(DEBUG, "Terminated %i actors", num_terminated) + async def submit( self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] ) -> Any: From ad94eaf0621de776a4a0a2645e41f27bc536b536 Mon Sep 17 00:00:00 2001 From: Sebastian van der Voort Date: Mon, 26 Feb 2024 11:45:06 +0100 Subject: [PATCH 02/17] Replace dead links, typo fixes and general improvements in framework docs (#2989) Co-authored-by: svdvoort <23049683+Svdvoort@users.noreply.github.com> --- doc/Makefile | 6 ++ ...contributor-how-to-build-docker-images.rst | 2 +- ...ributor-how-to-contribute-translations.rst | 4 +- ...contributor-how-to-create-new-messages.rst | 4 +- ...ow-to-develop-in-vscode-dev-containers.rst | 6 +- ...ntributor-ref-good-first-contributions.rst | 4 +- ...tributor-tutorial-contribute-on-github.rst | 35 +++--- ...-tutorial-get-started-as-a-contributor.rst | 6 +- ...-pytorch-from-centralized-to-federated.rst | 10 +- .../example-walkthrough-pytorch-mnist.rst | 100 +++++++++--------- doc/source/how-to-configure-clients.rst | 2 +- doc/source/how-to-install-flower.rst | 2 +- doc/source/how-to-monitor-simulation.rst | 4 +- doc/source/ref-example-projects.rst | 10 +- doc/source/ref-faq.rst | 4 +- .../tutorial-quickstart-huggingface.rst | 47 ++++---- doc/source/tutorial-quickstart-ios.rst | 6 +- 
doc/source/tutorial-quickstart-mxnet.rst | 4 +- doc/source/tutorial-quickstart-pytorch.rst | 2 +- .../tutorial-quickstart-scikitlearn.rst | 8 +- doc/source/tutorial-quickstart-xgboost.rst | 4 +- ...uild-a-strategy-from-scratch-pytorch.ipynb | 10 +- ...-series-customize-the-client-pytorch.ipynb | 8 +- ...ries-get-started-with-flower-pytorch.ipynb | 4 +- ...-federated-learning-strategy-pytorch.ipynb | 4 +- ...al-series-what-is-federated-learning.ipynb | 6 +- src/py/flwr/server/driver/driver.py | 1 + 27 files changed, 154 insertions(+), 149 deletions(-) diff --git a/doc/Makefile b/doc/Makefile index 085fe301e162..0e11bb3b4b3c 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -7,6 +7,7 @@ SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build +LINKCHECKDIR = build/linkcheck # Put it first so that "make" without argument is like "make help". help: @@ -29,6 +30,11 @@ serve: make docs python -m http.server --directory build/html +checklinks: + $(SPHINXBUILD) -b linkcheck "$(SOURCEDIR)" "$(LINKCHECKDIR)" + @echo + @echo "Check finished. Report is in $(LINKCHECKDIR)." + # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile diff --git a/doc/source/contributor-how-to-build-docker-images.rst b/doc/source/contributor-how-to-build-docker-images.rst index 3beae7422bef..5dead265bee2 100644 --- a/doc/source/contributor-how-to-build-docker-images.rst +++ b/doc/source/contributor-how-to-build-docker-images.rst @@ -17,7 +17,7 @@ Before we can start, we need to meet a few prerequisites in our local developmen #. Verify the Docker daemon is running. Please follow the first section on - `Run Flower using Docker `_ + :doc:`Run Flower using Docker ` which covers this step in more detail. Currently, Flower provides two images, a base image and a server image. 
There will also be a client diff --git a/doc/source/contributor-how-to-contribute-translations.rst b/doc/source/contributor-how-to-contribute-translations.rst index 1614b8e5a040..ba59901cf1c4 100644 --- a/doc/source/contributor-how-to-contribute-translations.rst +++ b/doc/source/contributor-how-to-contribute-translations.rst @@ -8,7 +8,7 @@ the translations are often imperfect. If you speak languages other than English, you might be able to help us in our effort to make Federated Learning accessible to as many people as possible by contributing to those translations! This might also be a great opportunity for those wanting to become open source -contributors with little prerequistes. +contributors with little prerequisites. Our translation project is publicly available over on `Weblate `_, this where most @@ -44,7 +44,7 @@ This is what the interface looks like: .. image:: _static/weblate_interface.png -You input your translation in the textbox at the top and then, once you are +You input your translation in the text box at the top and then, once you are happy with it, you either press ``Save and continue`` (to save the translation and go to the next untranslated string), ``Save and stay`` (to save the translation and stay on the same page), ``Suggest`` (to add your translation to diff --git a/doc/source/contributor-how-to-create-new-messages.rst b/doc/source/contributor-how-to-create-new-messages.rst index 24fa5f573158..5d9f4600361c 100644 --- a/doc/source/contributor-how-to-create-new-messages.rst +++ b/doc/source/contributor-how-to-create-new-messages.rst @@ -29,8 +29,8 @@ Let's now see what we need to implement in order to get this simple function bet Message Types for Protocol Buffers ---------------------------------- -The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`. -Note that we have to do it for both the request and response messages. 
For more details on the syntax of proto3, please see the `official documentation `_. +The first thing we need to do is to define a message type for the RPC system in :code:`transport.proto`. +Note that we have to do it for both the request and response messages. For more details on the syntax of proto3, please see the `official documentation `_. Within the :code:`ServerMessage` block: diff --git a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst index 19d46c5753c6..c861457b6edc 100644 --- a/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst +++ b/doc/source/contributor-how-to-develop-in-vscode-dev-containers.rst @@ -8,7 +8,7 @@ When working on the Flower framework we want to ensure that all contributors use Workspace files are mounted from the local file system or copied or cloned into the container. Extensions are installed and run inside the container, where they have full access to the tools, platform, and file system. This means that you can seamlessly switch your entire development environment just by connecting to a different container. -Source: `Official VSCode documentation `_ +Source: `Official VSCode documentation `_ Getting started @@ -20,5 +20,5 @@ Now you should be good to go. When starting VSCode, it will ask you to run in th In some cases your setup might be more involved. 
For those cases consult the following sources: -* `Developing inside a Container `_ -* `Remote development in Containers `_ +* `Developing inside a Container `_ +* `Remote development in Containers `_ diff --git a/doc/source/contributor-ref-good-first-contributions.rst b/doc/source/contributor-ref-good-first-contributions.rst index cbf21e2845bc..2b8ce88413f5 100644 --- a/doc/source/contributor-ref-good-first-contributions.rst +++ b/doc/source/contributor-ref-good-first-contributions.rst @@ -22,11 +22,11 @@ are: Request for Flower Baselines ---------------------------- -If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. +If you are not familiar with Flower Baselines, you should probably check-out our `contributing guide for baselines `_. You should then check out the open `issues `_ for baseline requests. -If you find a baseline that you'd like to work on and that has no assignes, feel free to assign it to yourself and start working on it! +If you find a baseline that you'd like to work on and that has no assignees, feel free to assign it to yourself and start working on it! Otherwise, if you don't find a baseline you'd like to work on, be sure to open a new issue with the baseline request template! diff --git a/doc/source/contributor-tutorial-contribute-on-github.rst b/doc/source/contributor-tutorial-contribute-on-github.rst index 273b47a636cc..6da81ce73662 100644 --- a/doc/source/contributor-tutorial-contribute-on-github.rst +++ b/doc/source/contributor-tutorial-contribute-on-github.rst @@ -3,8 +3,7 @@ Contribute on GitHub This guide is for people who want to get involved with Flower, but who are not used to contributing to GitHub projects. -If you're familiar with how contributing on GitHub works, you can directly checkout our -`getting started guide for contributors `_. 
+If you're familiar with how contributing on GitHub works, you can directly checkout our :doc:`getting started guide for contributors `. Setting up the repository @@ -12,7 +11,7 @@ Setting up the repository 1. **Create a GitHub account and setup Git** Git is a distributed version control tool. This allows for an entire codebase's history to be stored and every developer's machine. - It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. + It is a software that will need to be installed on your local machine, you can follow this `guide `_ to set it up. GitHub, itself, is a code hosting platform for version control and collaboration. It allows for everyone to collaborate and work from anywhere on remote repositories. @@ -22,7 +21,7 @@ Setting up the repository you download code from a remote repository on GitHub, make changes locally and keep track of them using Git and then you upload your new history back to GitHub. 2. **Forking the Flower repository** - A fork is a personal copy of a GitHub repository. To create one for Flower, you must navigate to https://github.com/adap/flower (while connected to your GitHub account) + A fork is a personal copy of a GitHub repository. To create one for Flower, you must navigate to ``_ (while connected to your GitHub account) and click the ``Fork`` button situated on the top right of the page. .. image:: _static/fork_button.png @@ -68,7 +67,7 @@ Setting up the repository 5. **Add upstream** Now we will add an upstream address to our repository. - Still in the same directroy, we must run the following command: + Still in the same directory, we must run the following command: .. code-block:: shell @@ -93,7 +92,7 @@ Setting up the repository Setting up the coding environment --------------------------------- -This can be achieved by following this `getting started guide for contributors`_ (note that you won't need to clone the repository). 
+This can be achieved by following this :doc:`getting started guide for contributors ` (note that you won't need to clone the repository). Once you are able to write code and test it, you can finally start making changes! @@ -256,28 +255,28 @@ Example of first contribution Problem ******* -For our documentation, we’ve started to use the `Diàtaxis framework `_. +For our documentation, we've started to use the `Diàtaxis framework `_. -Our “How to” guides should have titles that continue the sencence “How to …”, for example, “How to upgrade to Flower 1.0”. +Our "How to" guides should have titles that continue the sentence "How to …", for example, "How to upgrade to Flower 1.0". Most of our guides do not follow this new format yet, and changing their title is (unfortunately) more involved than one might think. -This issue is about changing the title of a doc from present continious to present simple. +This issue is about changing the title of a doc from present continuous to present simple. -Let's take the example of “Saving Progress” which we changed to “Save Progress”. Does this pass our check? +Let's take the example of "Saving Progress" which we changed to "Save Progress". Does this pass our check? -Before: ”How to saving progress” ❌ +Before: "How to saving progress" ❌ -After: ”How to save progress” ✅ +After: "How to save progress" ✅ Solution ******** -This is a tiny change, but it’ll allow us to test your end-to-end setup. After cloning and setting up the Flower repo, here’s what you should do: +This is a tiny change, but it'll allow us to test your end-to-end setup. 
After cloning and setting up the Flower repo, here's what you should do: - Find the source file in ``doc/source`` - Make the change in the ``.rst`` file (beware, the dashes under the title should be the same length as the title itself) -- Build the docs and check the result: ``_ +- Build the docs and `check the result `_ Rename file ::::::::::: @@ -285,7 +284,7 @@ Rename file You might have noticed that the file name still reflects the old wording. If we just change the file, then we break all existing links to it - it is **very important** to avoid that, breaking links can harm our search engine ranking. -Here’s how to change the file name: +Here's how to change the file name: - Change the file name to ``save-progress.rst`` - Add a redirect rule to ``doc/source/conf.py`` @@ -303,7 +302,7 @@ This is where we define the whole arborescence of the navbar. Open PR ::::::: -- Commit the changes (commit messages are always imperative: “Do something”, in this case “Change …”) +- Commit the changes (commit messages are always imperative: "Do something", in this case "Change …") - Push the changes to your fork - Open a PR (as shown above) - Wait for it to be approved! @@ -343,7 +342,7 @@ Next steps Once you have made your first PR, and want to contribute more, be sure to check out the following : -- `Good first contributions `_, where you should particularly look into the :code:`baselines` contributions. +- :doc:`Good first contributions `, where you should particularly look into the :code:`baselines` contributions. Appendix @@ -361,7 +360,7 @@ Above this header you should see the following comment that explains how to writ Inside the following 'Changelog entry' section, you should put the description of your changes that will be added to the changelog alongside your PR title. 
- If the section is completely empty (without any token) or non-existant, + If the section is completely empty (without any token) or non-existent, the changelog will just contain the title of the PR for the changelog entry, without any description. If the section contains some text other than tokens, it will use it to add a description to the change. diff --git a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst index 72c6df5fdbc7..01810c7244d3 100644 --- a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst +++ b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst @@ -11,7 +11,7 @@ Prerequisites Flower uses :code:`pyproject.toml` to manage dependencies and configure development tools (the ones which support it). Poetry is a build tool which -supports `PEP 517 `_. +supports `PEP 517 `_. Developer Machine Setup @@ -27,7 +27,7 @@ For macOS * Install `homebrew `_. Don't forget the post-installation actions to add `brew` to your PATH. 
* Install `xz` (to install different Python versions) and `pandoc` to build the docs:: - + $ brew install xz pandoc For Ubuntu @@ -54,7 +54,7 @@ GitHub:: * If you don't have :code:`pyenv` installed, the following script that will install it, set it up, and create the virtual environment (with :code:`Python 3.8.17` by default):: $ ./dev/setup-defaults.sh # once completed, run the bootstrap script - + * If you already have :code:`pyenv` installed (along with the :code:`pyenv-virtualenv` plugin), you can use the following convenience script (with :code:`Python 3.8.17` by default):: $ ./dev/venv-create.sh # once completed, run the `bootstrap.sh` script diff --git a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst index 5d4dac0c0cda..0139f3b8dc31 100644 --- a/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst +++ b/doc/source/example-fedbn-pytorch-from-centralized-to-federated.rst @@ -3,11 +3,11 @@ Example: FedBN in PyTorch - From Centralized To Federated This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with `FedBN `_, a federated training strategy designed for non-iid data. We are using PyTorch to train a Convolutional Neural Network(with Batch Normalization layers) on the CIFAR-10 dataset. -When applying FedBN, only few changes needed compared to `Example: PyTorch - From Centralized To Federated `_. +When applying FedBN, only few changes needed compared to :doc:`Example: PyTorch - From Centralized To Federated `. Centralized Training -------------------- -All files are revised based on `Example: PyTorch - From Centralized To Federated `_. +All files are revised based on :doc:`Example: PyTorch - From Centralized To Federated `. 
The only thing to do is modifying the file called :code:`cifar.py`, revised part is shown below: The model architecture defined in class Net() is added with Batch Normalization layers accordingly. @@ -45,13 +45,13 @@ You can now run your machine learning workload: python3 cifar.py So far this should all look fairly familiar if you've used PyTorch before. -Let's take the next step and use what we've built to create a federated learning system within FedBN, the sytstem consists of one server and two clients. +Let's take the next step and use what we've built to create a federated learning system within FedBN, the system consists of one server and two clients. Federated Training ------------------ -If you have read `Example: PyTorch - From Centralized To Federated `_, the following parts are easy to follow, onyl :code:`get_parameters` and :code:`set_parameters` function in :code:`client.py` needed to revise. -If not, please read the `Example: PyTorch - From Centralized To Federated `_. first. +If you have read :doc:`Example: PyTorch - From Centralized To Federated `, the following parts are easy to follow, only :code:`get_parameters` and :code:`set_parameters` function in :code:`client.py` needed to revise. +If not, please read the :doc:`Example: PyTorch - From Centralized To Federated `. first. Our example consists of one *server* and two *clients*. In FedBN, :code:`server.py` keeps unchanged, we can start the server directly. diff --git a/doc/source/example-walkthrough-pytorch-mnist.rst b/doc/source/example-walkthrough-pytorch-mnist.rst index 0be0af6e1ca6..8717c196f043 100644 --- a/doc/source/example-walkthrough-pytorch-mnist.rst +++ b/doc/source/example-walkthrough-pytorch-mnist.rst @@ -1,11 +1,11 @@ Example: Walk-Through PyTorch & MNIST ===================================== -In this tutorial we will learn, how to train a Convolutional Neural Network on MNIST using Flower and PyTorch. 
+In this tutorial we will learn, how to train a Convolutional Neural Network on MNIST using Flower and PyTorch. -Our example consists of one *server* and two *clients* all having the same model. +Our example consists of one *server* and two *clients* all having the same model. -*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. +*Clients* are responsible for generating individual weight-updates for the model based on their local datasets. These updates are then sent to the *server* which will aggregate them to produce a better model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of weight updates is called a *round*. @@ -15,7 +15,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n $ pip install flwr -Since we want to use PyTorch to solve a computer vision task, let's go ahead an install PyTorch and the **torchvision** library: +Since we want to use PyTorch to solve a computer vision task, let's go ahead an install PyTorch and the **torchvision** library: .. code-block:: shell @@ -32,51 +32,51 @@ Go ahead and launch on a terminal the *run-server.sh* script first as follows: .. code-block:: shell - $ bash ./run-server.sh + $ bash ./run-server.sh -Now that the server is up and running, go ahead and launch the clients. +Now that the server is up and running, go ahead and launch the clients. .. code-block:: shell - $ bash ./run-clients.sh + $ bash ./run-clients.sh Et voilà! You should be seeing the training procedure and, after a few iterations, the test accuracy for each client. .. 
code-block:: shell - Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014 - - Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403 - - Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280 - - Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641 - - Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784 - - Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134 - - Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16% - + Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014 + + Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403 + + Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280 + + Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641 + + Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784 + + Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134 + + Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16% + Client 0 - Evaluate on 5000 samples: Average loss: 0.0328, Accuracy: 99.14% -Now, let's see what is really happening inside. +Now, let's see what is really happening inside. Flower Server ------------- Inside the server helper script *run-server.sh* you will find the following code that basically runs the :code:`server.py` -.. code-block:: bash +.. code-block:: bash python -m flwr_example.quickstart-pytorch.server We can go a bit deeper and see that :code:`server.py` simply launches a server that will coordinate three rounds of training. -Flower Servers are very customizable, but for simple workloads, we can start a server using the `start_server `_ function and leave all the configuration possibilities at their default values, as seen below. +Flower Servers are very customizable, but for simple workloads, we can start a server using the `start_server `_ function and leave all the configuration possibilities at their default values, as seen below. .. code-block:: python @@ -90,18 +90,18 @@ Flower Client Next, let's take a look at the *run-clients.sh* file. You will see that it contains the main loop that starts a set of *clients*. 
-.. code-block:: bash +.. code-block:: bash python -m flwr_example.quickstart-pytorch.client \ --cid=$i \ --server_address=$SERVER_ADDRESS \ - --nb_clients=$NUM_CLIENTS + --nb_clients=$NUM_CLIENTS * **cid**: is the client ID. It is an integer that uniquely identifies client identifier. -* **sever_address**: String that identifies IP and port of the server. +* **sever_address**: String that identifies IP and port of the server. * **nb_clients**: This defines the number of clients being created. This piece of information is not required by the client, but it helps us partition the original MNIST dataset to make sure that every client is working on unique subsets of both *training* and *test* sets. -Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`. +Again, we can go deeper and look inside :code:`flwr_example/quickstart-pytorch/client.py`. After going through the argument parsing code at the beginning of our :code:`main` function, you will find a call to :code:`mnist.load_data`. This function is responsible for partitioning the original MNIST datasets (*training* and *test*) and returning a :code:`torch.utils.data.DataLoader` s for each of them. We then instantiate a :code:`PytorchMNISTClient` object with our client ID, our DataLoaders, the number of epochs in each round, and which device we want to use for training (CPU or GPU). @@ -152,7 +152,7 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - weights: fl.common.NDArrays + weights: fl.common.NDArrays Weights received by the server and set to local model @@ -179,8 +179,8 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - ins: fl.common.FitIns - Parameters sent by the server to be used during training. + ins: fl.common.FitIns + Parameters sent by the server to be used during training. 
Returns ------- @@ -214,9 +214,9 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e Parameters ---------- - ins: fl.common.EvaluateIns - Parameters sent by the server to be used during testing. - + ins: fl.common.EvaluateIns + Parameters sent by the server to be used during testing. + Returns ------- @@ -262,9 +262,9 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it Parameters ---------- - x: Tensor + x: Tensor Mini-batch of shape (N,28,28) containing images from MNIST dataset. - + Returns ------- @@ -287,7 +287,7 @@ The code for the CNN is available under :code:`quickstart-pytorch.mnist` and it return output -The second thing to notice is that :code:`PytorchMNISTClient` class inherits from the :code:`fl.client.Client`, and hence it must implement the following methods: +The second thing to notice is that :code:`PytorchMNISTClient` class inherits from the :code:`fl.client.Client`, and hence it must implement the following methods: .. code-block:: python @@ -312,7 +312,7 @@ The second thing to notice is that :code:`PytorchMNISTClient` class inherits fro """Evaluate the provided weights using the locally held dataset.""" -When comparing the abstract class to its derived class :code:`PytorchMNISTClient` you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test`: function. +When comparing the abstract class to its derived class :code:`PytorchMNISTClient` you will notice that :code:`fit` calls a :code:`train` function and that :code:`evaluate` calls a :code:`test`: function. These functions can both be found inside the same :code:`quickstart-pytorch.mnist` module: @@ -330,14 +330,14 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis ---------- model: torch.nn.ModuleList Neural network model used in this example. - + train_loader: torch.utils.data.DataLoader DataLoader used in traning. 
- - epochs: int - Number of epochs to run in each round. - - device: torch.device + + epochs: int + Number of epochs to run in each round. + + device: torch.device (Default value = torch.device("cpu")) Device where the network will be trained within a client. @@ -399,10 +399,10 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis ---------- model: torch.nn.ModuleList : Neural network model used in this example. - + test_loader: torch.utils.data.DataLoader : DataLoader used in test. - + device: torch.device : (Default value = torch.device("cpu")) Device where the network will be tested within a client. @@ -435,19 +435,19 @@ These functions can both be found inside the same :code:`quickstart-pytorch.mnis Observe that these functions encapsulate regular training and test loops and provide :code:`fit` and :code:`evaluate` with final statistics for each round. -You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly. -As a matter of fact, why not try and modify the code to an example of your liking? +You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly. +As a matter of fact, why not try and modify the code to an example of your liking? Give It a Try ------------- -Looking through the quickstart code description above will have given a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper. +Looking through the quickstart code description above will have given a good understanding of how *clients* and *servers* work in Flower, how to run a simple experiment, and the internals of a client wrapper. Here are a few things you could try on your own and get more experience with Flower: - Try and change :code:`PytorchMNISTClient` so it can accept different architectures. 
- Modify the :code:`train` function so that it accepts different optimizers - Modify the :code:`test` function so that it proves not only the top-1 (regular accuracy) but also the top-5 accuracy? -- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50? +- Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50? You are ready now. Enjoy learning in a federated way! diff --git a/doc/source/how-to-configure-clients.rst b/doc/source/how-to-configure-clients.rst index bfb5a8f63761..ff0a2f4033df 100644 --- a/doc/source/how-to-configure-clients.rst +++ b/doc/source/how-to-configure-clients.rst @@ -86,7 +86,7 @@ Configuring individual clients In some cases, it is necessary to send different configuration values to different clients. -This can be achieved by customizing an existing strategy or by `implementing a custom strategy from scratch `_. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list, the other clients in this round to not receive this "special" config value): +This can be achieved by customizing an existing strategy or by :doc:`implementing a custom strategy from scratch `. Here's a nonsensical example that customizes :code:`FedAvg` by adding a custom ``"hello": "world"`` configuration key/value pair to the config dict of a *single client* (only the first client in the list, the other clients in this round to not receive this "special" config value): .. 
code-block:: python diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index dc88076424f8..aebe5f7316de 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -57,7 +57,7 @@ Advanced installation options Install via Docker ~~~~~~~~~~~~~~~~~~ -`How to run Flower using Docker `_ +:doc:`How to run Flower using Docker ` Install pre-release ~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/how-to-monitor-simulation.rst b/doc/source/how-to-monitor-simulation.rst index 740004914eed..61a3be68deec 100644 --- a/doc/source/how-to-monitor-simulation.rst +++ b/doc/source/how-to-monitor-simulation.rst @@ -231,6 +231,6 @@ A: Either the simulation has already finished, or you still need to start Promet Resources --------- -Ray Dashboard: ``_ +Ray Dashboard: ``_ -Ray Metrics: ``_ +Ray Metrics: ``_ diff --git a/doc/source/ref-example-projects.rst b/doc/source/ref-example-projects.rst index 8eb723000cac..bade86dfaa54 100644 --- a/doc/source/ref-example-projects.rst +++ b/doc/source/ref-example-projects.rst @@ -23,7 +23,7 @@ The TensorFlow/Keras quickstart example shows CIFAR-10 image classification with MobileNetV2: - `Quickstart TensorFlow (Code) `_ -- `Quickstart TensorFlow (Tutorial) `_ +- :doc:`Quickstart TensorFlow (Tutorial) ` - `Quickstart TensorFlow (Blog Post) `_ @@ -34,7 +34,7 @@ The PyTorch quickstart example shows CIFAR-10 image classification with a simple Convolutional Neural Network: - `Quickstart PyTorch (Code) `_ -- `Quickstart PyTorch (Tutorial) `_ +- :doc:`Quickstart PyTorch (Tutorial) ` PyTorch: From Centralized To Federated @@ -43,7 +43,7 @@ PyTorch: From Centralized To Federated This example shows how a regular PyTorch project can be federated using Flower: - `PyTorch: From Centralized To Federated (Code) `_ -- `PyTorch: From Centralized To Federated (Tutorial) `_ +- :doc:`PyTorch: From Centralized To Federated (Tutorial) ` Federated Learning on Raspberry Pi and Nvidia Jetson @@ -60,7 +60,7 
@@ Legacy Examples (`flwr_example`) -------------------------------- .. warning:: - The useage examples in `flwr_example` are deprecated and will be removed in + The usage examples in `flwr_example` are deprecated and will be removed in the future. New examples are provided as standalone projects in `examples `_. @@ -114,7 +114,7 @@ For more details, see :code:`src/py/flwr_example/pytorch_cifar`. ImageNet-2012 Image Classification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -`ImageNet-2012 `_ is one of the major computer +`ImageNet-2012 `_ is one of the major computer vision datasets. The Flower ImageNet example uses PyTorch to train a ResNet-18 classifier in a federated learning setup with ten clients. diff --git a/doc/source/ref-faq.rst b/doc/source/ref-faq.rst index 932396e3c583..26b7dca4a0a7 100644 --- a/doc/source/ref-faq.rst +++ b/doc/source/ref-faq.rst @@ -3,7 +3,7 @@ FAQ This page collects answers to commonly asked questions about Federated Learning with Flower. -.. dropdown:: :fa:`eye,mr-1` Can Flower run on Juptyter Notebooks / Google Colab? +.. dropdown:: :fa:`eye,mr-1` Can Flower run on Jupyter Notebooks / Google Colab? Yes, it can! Flower even comes with a few under-the-hood optimizations to make it work even better on Colab. Here's a quickstart example: @@ -27,6 +27,6 @@ This page collects answers to commonly asked questions about Federated Learning * `Flower meets Nevermined GitHub Repository `_. * `Flower meets Nevermined YouTube video `_. - * `Flower meets KOSMoS `_. + * `Flower meets KOSMoS `_. * `Flower meets Talan blog post `_ . * `Flower meets Talan GitHub Repository `_ . diff --git a/doc/source/tutorial-quickstart-huggingface.rst b/doc/source/tutorial-quickstart-huggingface.rst index 1e06120b452f..7d8128230901 100644 --- a/doc/source/tutorial-quickstart-huggingface.rst +++ b/doc/source/tutorial-quickstart-huggingface.rst @@ -9,8 +9,8 @@ Quickstart 🤗 Transformers Let's build a federated learning system using Hugging Face Transformers and Flower! 
-We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. -More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) +We will leverage Hugging Face to federate the training of language models over multiple clients using Flower. +More specifically, we will fine-tune a pre-trained Transformer model (distilBERT) for sequence classification over a dataset of IMDB ratings. The end goal is to detect if a movie rating is positive or negative. @@ -32,8 +32,8 @@ Standard Hugging Face workflow Handling the data ^^^^^^^^^^^^^^^^^ -To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. -We then need to tokenize the data and create :code:`PyTorch` dataloaders, +To fetch the IMDB dataset, we will use Hugging Face's :code:`datasets` library. +We then need to tokenize the data and create :code:`PyTorch` dataloaders, this is all done in the :code:`load_data` function: .. code-block:: python @@ -80,8 +80,8 @@ this is all done in the :code:`load_data` function: Training and testing the model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Once we have a way of creating our trainloader and testloader, -we can take care of the training and testing. +Once we have a way of creating our trainloader and testloader, +we can take care of the training and testing. This is very similar to any :code:`PyTorch` training or testing loop: .. code-block:: python @@ -120,12 +120,12 @@ This is very similar to any :code:`PyTorch` training or testing loop: Creating the model itself ^^^^^^^^^^^^^^^^^^^^^^^^^ -To create the model itself, +To create the model itself, we will just load the pre-trained distillBERT model using Hugging Face’s :code:`AutoModelForSequenceClassification` : .. 
code-block:: python - from transformers import AutoModelForSequenceClassification + from transformers import AutoModelForSequenceClassification net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 @@ -138,8 +138,8 @@ Federating the example Creating the IMDBClient ^^^^^^^^^^^^^^^^^^^^^^^ -To federate our example to multiple clients, -we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). +To federate our example to multiple clients, +we first need to write our Flower client class (inheriting from :code:`flwr.client.NumPyClient`). This is very easy, as our model is a standard :code:`PyTorch` model: .. code-block:: python @@ -166,17 +166,17 @@ This is very easy, as our model is a standard :code:`PyTorch` model: return float(loss), len(testloader), {"accuracy": float(accuracy)} -The :code:`get_parameters` function lets the server get the client's parameters. -Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client. -Finally, the :code:`fit` function trains the model locally for the client, -and the :code:`evaluate` function tests the model locally and returns the relevant metrics. +The :code:`get_parameters` function lets the server get the client's parameters. +Inversely, the :code:`set_parameters` function allows the server to send its parameters to the client. +Finally, the :code:`fit` function trains the model locally for the client, +and the :code:`evaluate` function tests the model locally and returns the relevant metrics. Starting the server ^^^^^^^^^^^^^^^^^^^ -Now that we have a way to instantiate clients, we need to create our server in order to aggregate the results. 
-Using Flower, this can be done very easily by first choosing a strategy (here, we are using :code:`FedAvg`, -which will define the global weights as the average of all the clients' weights at each round) +Now that we have a way to instantiate clients, we need to create our server in order to aggregate the results. +Using Flower, this can be done very easily by first choosing a strategy (here, we are using :code:`FedAvg`, +which will define the global weights as the average of all the clients' weights at each round) and then using the :code:`flwr.server.start_server` function: .. code-block:: python @@ -186,7 +186,7 @@ and then using the :code:`flwr.server.start_server` function: losses = [num_examples * m["loss"] for num_examples, m in metrics] examples = [num_examples for num_examples, _ in metrics] return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)} - + # Define strategy strategy = fl.server.strategy.FedAvg( fraction_fit=1.0, @@ -202,7 +202,7 @@ and then using the :code:`flwr.server.start_server` function: ) -The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst +The :code:`weighted_average` function is there to provide a way to aggregate the metrics distributed amongst the clients (basically this allows us to display a nice average accuracy and loss for every round). Putting everything together @@ -213,18 +213,17 @@ We can now start client instances using: .. code-block:: python fl.client.start_client( - server_address="127.0.0.1:8080", + server_address="127.0.0.1:8080", client=IMDBClient().to_client() ) And they will be able to connect to the server and start the federated training. -If you want to check out everything put together, -you should check out the full code example: -[https://github.com/adap/flower/tree/main/examples/quickstart-huggingface](https://github.com/adap/flower/tree/main/examples/quickstart-huggingface). 
+If you want to check out everything put together, +you should check out the `full code example `_ . -Of course, this is a very basic example, and a lot can be added or modified, +Of course, this is a very basic example, and a lot can be added or modified, it was just to showcase how simply we could federate a Hugging Face workflow using Flower. Note that in this example we used :code:`PyTorch`, but we could have very well used :code:`TensorFlow`. diff --git a/doc/source/tutorial-quickstart-ios.rst b/doc/source/tutorial-quickstart-ios.rst index aa94a72580c1..e4315ce569fb 100644 --- a/doc/source/tutorial-quickstart-ios.rst +++ b/doc/source/tutorial-quickstart-ios.rst @@ -9,7 +9,7 @@ Quickstart iOS In this tutorial we will learn how to train a Neural Network on MNIST using Flower and CoreML on iOS devices. -First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, for running the Flower Python server, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. For the Flower client implementation in iOS, it is recommended to use Xcode as our IDE. Our example consists of one Python *server* and two iPhone *clients* that all have the same model. @@ -97,9 +97,9 @@ For the MNIST dataset, we need to preprocess it into :code:`MLBatchProvider` obj testBatchProvider: testBatchProvider) Since CoreML does not allow the model parameters to be seen before training, and accessing the model parameters during or after the training can only be done by specifying the layer name, -we need to know this informations beforehand, through looking at the model specification, which are written as proto files. The implementation can be seen in :code:`MLModelInspect`. +we need to know this information beforehand, through looking at the model specification, which are written as proto files. The implementation can be seen in :code:`MLModelInspect`. 
-After we have all of the necessary informations, let's create our Flower client. +After we have all of the necessary information, let's create our Flower client. .. code-block:: swift diff --git a/doc/source/tutorial-quickstart-mxnet.rst b/doc/source/tutorial-quickstart-mxnet.rst index 08304483af86..fe582f793280 100644 --- a/doc/source/tutorial-quickstart-mxnet.rst +++ b/doc/source/tutorial-quickstart-mxnet.rst @@ -4,14 +4,14 @@ Quickstart MXNet ================ -.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongise Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower. +.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongside Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower. .. meta:: :description: Check out this Federated Learning quickstart tutorial for using Flower with MXNet to train a Sequential model on MNIST. In this tutorial, we will learn how to train a :code:`Sequential` model on MNIST using Flower and MXNet. -It is recommended to create a virtual environment and run everything within this `virtualenv `_. +It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `. Our example consists of one *server* and two *clients* all having the same model. diff --git a/doc/source/tutorial-quickstart-pytorch.rst b/doc/source/tutorial-quickstart-pytorch.rst index 32f9c5ebb3a1..895590808a2b 100644 --- a/doc/source/tutorial-quickstart-pytorch.rst +++ b/doc/source/tutorial-quickstart-pytorch.rst @@ -12,7 +12,7 @@ Quickstart PyTorch In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR10 using Flower and PyTorch. -First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_. 
+First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. Our example consists of one *server* and two *clients* all having the same model. diff --git a/doc/source/tutorial-quickstart-scikitlearn.rst b/doc/source/tutorial-quickstart-scikitlearn.rst index b95118aa091f..d1d47dc37f19 100644 --- a/doc/source/tutorial-quickstart-scikitlearn.rst +++ b/doc/source/tutorial-quickstart-scikitlearn.rst @@ -9,7 +9,7 @@ Quickstart scikit-learn In this tutorial, we will learn how to train a :code:`Logistic Regression` model on MNIST using Flower and scikit-learn. -It is recommended to create a virtual environment and run everything within this `virtualenv `_. +It is recommended to create a virtual environment and run everything within this :doc:`virtualenv `. Our example consists of one *server* and two *clients* all having the same model. @@ -23,7 +23,7 @@ Now that we have a rough idea of what is going on, let's get started. We first n $ pip install flwr -Since we want to use scikt-learn, let's go ahead and install it: +Since we want to use scikit-learn, let's go ahead and install it: .. code-block:: shell @@ -43,7 +43,7 @@ Now that we have all our dependencies installed, let's run a simple distributed However, before setting up the client and server, we will define all functionalities that we need for our federated learning setup within :code:`utils.py`. The :code:`utils.py` contains different functions defining all the machine learning basics: * :code:`get_model_parameters()` - * Returns the paramters of a :code:`sklearn` LogisticRegression model + * Returns the parameters of a :code:`sklearn` LogisticRegression model * :code:`set_model_params()` * Sets the parameters of a :code:`sklean` LogisticRegression model * :code:`set_initial_params()` @@ -70,7 +70,7 @@ The pre-defined functions are used in the :code:`client.py` and imported. 
The :c import utils -We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`. +We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`. .. code-block:: python diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst index ec9101f4b3fd..751024db14e4 100644 --- a/doc/source/tutorial-quickstart-xgboost.rst +++ b/doc/source/tutorial-quickstart-xgboost.rst @@ -36,7 +36,7 @@ and then we dive into a more complex example (`full code xgboost-comprehensive < Environment Setup -------------------- -First of all, it is recommended to create a virtual environment and run everything within a `virtualenv `_. +First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv `. We first need to install Flower and Flower Datasets. You can do this by running : @@ -596,7 +596,7 @@ Comprehensive Federated XGBoost Now that you have known how federated XGBoost work with Flower, it's time to run some more comprehensive experiments by customising the experimental settings. In the xgboost-comprehensive example (`full code `_), we provide more options to define various experimental setups, including aggregation strategies, data partitioning and centralised/distributed evaluation. -We also support `Flower simulation `_ making it easy to simulate large client cohorts in a resource-aware manner. +We also support :doc:`Flower simulation ` making it easy to simulate large client cohorts in a resource-aware manner. Let's take a look! 
Cyclic training diff --git a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb index 5b2236468909..c5fc777e7f26 100644 --- a/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb +++ b/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb @@ -7,11 +7,11 @@ "source": [ "# Build a strategy from scratch\n", "\n", - "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n", + "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n", "\n", - "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.dev/) and [PyTorch](https://pytorch.org/)).\n", + "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using [Flower](https://flower.ai/) and [PyTorch](https://pytorch.org/)).\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join 
Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's build a new `Strategy` from scratch!" ] @@ -489,11 +489,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 4](https://flower.dev/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." + "The [Flower Federated Learning Tutorial - Part 4](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`." 
] } ], diff --git a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb index bcfdeb30d3c7..ce09c6cc46c1 100644 --- a/doc/source/tutorial-series-customize-the-client-pytorch.ipynb +++ b/doc/source/tutorial-series-customize-the-client-pytorch.ipynb @@ -869,16 +869,16 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", "This is the final part of the Flower tutorial (for now!), congratulations! You're now well equipped to understand the rest of the documentation. There are many topics we didn't cover in the tutorial, we recommend the following resources:\n", "\n", - "- [Read Flower Docs](https://flower.dev/docs/)\n", + "- [Read Flower Docs](https://flower.ai/docs/)\n", "- [Check out Flower Code Examples](https://github.com/adap/flower/tree/main/examples)\n", - "- [Use Flower Baselines for your research](https://flower.dev/docs/baselines/)\n", - "- [Watch Flower Summit 2023 videos](https://flower.dev/conf/flower-summit-2023/)\n" + "- [Use Flower Baselines for your research](https://flower.ai/docs/baselines/)\n", + "- [Watch Flower Summit 2023 videos](https://flower.ai/conf/flower-summit-2023/)\n" ] } ], diff --git a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb index f4b8acaa5bb8..205531c54ee6 100644 --- a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb +++ b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb @@ -605,11 +605,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower 
community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n" + "The [Flower Federated Learning Tutorial - Part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n" ] } ], diff --git a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb index c758b8f637b0..e20a8d83f674 100644 --- a/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb +++ b/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb @@ -614,11 +614,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 3](https://flower.dev/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." 
+ "The [Flower Federated Learning Tutorial - Part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch." ] } ], diff --git a/doc/source/tutorial-series-what-is-federated-learning.ipynb b/doc/source/tutorial-series-what-is-federated-learning.ipynb index 3f7e383b9fbc..d77182838f21 100755 --- a/doc/source/tutorial-series-what-is-federated-learning.ipynb +++ b/doc/source/tutorial-series-what-is-federated-learning.ipynb @@ -13,7 +13,7 @@ "\n", "🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed.\n", "\n", - "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", + "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the open-source Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.ai/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n", "\n", "Let's get started!" 
] @@ -217,11 +217,11 @@ "source": [ "## Next steps\n", "\n", - "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n", + "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.ai/join-slack/)\n", "\n", "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n", "\n", - "The [Flower Federated Learning Tutorial - Part 1](https://flower.dev/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." + "The [Flower Federated Learning Tutorial - Part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html) shows how to build a simple federated learning system with PyTorch and Flower." ] } ], diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index aa8dc1ddd6c0..bcaac1f61b85 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -44,6 +44,7 @@ class Driver: Tuple containing root certificate, server certificate, and private key to start a secure SSL-enabled server. The tuple is expected to have three bytes elements in the following order: + * CA certificate. * server certificate. * server private key. 
From 27ae762491abe07ac9160614d090b92415827139 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 26 Feb 2024 14:37:38 +0100 Subject: [PATCH 03/17] Add build/process/terminate tests for `RayBackend` (#3011) --- .../superlink/fleet/vce/backend/raybackend.py | 3 +- .../fleet/vce/backend/raybackend_test.py | 141 ++++++++++++++++++ 2 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index cc3cf4348491..b29d76b239e5 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -14,7 +14,6 @@ # ============================================================================== """Ray backend for the Fleet API using the Simulation Engine.""" -import asyncio import pathlib from logging import INFO from typing import Callable, Dict, List, Tuple, Union @@ -145,7 +144,7 @@ async def process_message( (app, message, str(node_id), context), ) - await asyncio.wait([future]) + await future # Fetch result ( diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py new file mode 100644 index 000000000000..f0cca527ab96 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -0,0 +1,141 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Ray backend for the Fleet API using the Simulation Engine.""" + +import asyncio +from math import pi +from typing import Callable, Dict, Optional, Tuple, Union +from unittest import IsolatedAsyncioTestCase + +import ray + +from flwr.client import Client, NumPyClient +from flwr.client.clientapp import ClientApp +from flwr.common import ( + Config, + ConfigsRecord, + Context, + GetPropertiesIns, + Message, + Metadata, + RecordSet, + Scalar, +) +from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES +from flwr.common.recordset_compat import getpropertiesins_to_recordset + +from .raybackend import RayBackend + + +class DummyClient(NumPyClient): + """A dummy NumPyClient for tests.""" + + def get_properties(self, config: Config) -> Dict[str, Scalar]: + """Return properties by doing a simple calculation.""" + result = float(config["factor"]) * pi + + # store something in context + self.context.state.configs_records["result"] = ConfigsRecord({"result": result}) + return {"result": result} + + +def get_dummy_client(cid: str) -> Client: # pylint: disable=unused-argument + """Return a DummyClient converted to Client type.""" + return DummyClient().to_client() + + +def _load_app() -> ClientApp: + return ClientApp(client_fn=get_dummy_client) + + +async def backend_build_process_and_termination( + backend: RayBackend, + process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None, +) -> Union[Tuple[Message, Context], None]: + """Build, process job and 
terminate RayBackend.""" + await backend.build() + to_return = None + + if process_args: + to_return = await backend.process_message(*process_args) + + await backend.terminate() + + ray.shutdown() + + return to_return + + +class AsyncTestRayBackend(IsolatedAsyncioTestCase): + """A basic class that allows runnig multliple asyncio tests.""" + + def test_backend_creation_and_termination(self) -> None: + """Test creation of RayBackend and its termination.""" + backend = RayBackend(backend_config={}, work_dir="") + asyncio.run( + backend_build_process_and_termination(backend=backend, process_args=None) + ) + + def test_backend_creation_submit_and_termination(self) -> None: + """Test submit.""" + backend = RayBackend(backend_config={}, work_dir="") + + # Define ClientApp + client_app_callable = _load_app + + # Construct a Message + mult_factor = 2024 + getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) + recordset = getpropertiesins_to_recordset(getproperties_ins) + message = Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + ttl="", + message_type=MESSAGE_TYPE_GET_PROPERTIES, + ), + ) + + # Construct emtpy Context + context = Context(state=RecordSet()) + + res = asyncio.run( + backend_build_process_and_termination( + backend=backend, process_args=(client_app_callable, message, context) + ) + ) + + if res is None: + raise AssertionError("This shouldn't happen") + + out_mssg, updated_context = res + + # Verify message content is as expected + content = out_mssg.content + assert ( + content.configs_records["getpropertiesres.properties"]["result"] + == pi * mult_factor + ) + + # Verify context is correct + obtained_result_in_context = updated_context.state.configs_records["result"][ + "result" + ] + assert obtained_result_in_context == pi * mult_factor From 514f670948776ee3468388ca22540e8e4d64ee4f Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 26 Feb 
2024 22:03:20 +0100 Subject: [PATCH 04/17] Make `InMemoryState` thread-safe when handling `TaskIns` (#3012) --- .../server/superlink/state/in_memory_state.py | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ecb39f18300a..690fadc032d7 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -16,6 +16,7 @@ import os +import threading from datetime import datetime, timedelta from logging import ERROR from typing import Dict, List, Optional, Set @@ -35,6 +36,7 @@ def __init__(self) -> None: self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} + self.lock = threading.Lock() def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: """Store one TaskIns.""" @@ -57,7 +59,8 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() task_ins.task.ttl = ttl.isoformat() - self.task_ins_store[task_id] = task_ins + with self.lock: + self.task_ins_store[task_id] = task_ins # Return the new task_id return task_id @@ -71,22 +74,23 @@ def get_task_ins( # Find TaskIns for node_id that were not delivered yet task_ins_list: List[TaskIns] = [] - for _, task_ins in self.task_ins_store.items(): - # pylint: disable=too-many-boolean-expressions - if ( - node_id is not None # Not anonymous - and task_ins.task.consumer.anonymous is False - and task_ins.task.consumer.node_id == node_id - and task_ins.task.delivered_at == "" - ) or ( - node_id is None # Anonymous - and task_ins.task.consumer.anonymous is True - and task_ins.task.consumer.node_id == 0 - and task_ins.task.delivered_at == "" - ): - task_ins_list.append(task_ins) - if limit and len(task_ins_list) == limit: - break + with self.lock: + for _, 
task_ins in self.task_ins_store.items(): + # pylint: disable=too-many-boolean-expressions + if ( + node_id is not None # Not anonymous + and task_ins.task.consumer.anonymous is False + and task_ins.task.consumer.node_id == node_id + and task_ins.task.delivered_at == "" + ) or ( + node_id is None # Anonymous + and task_ins.task.consumer.anonymous is True + and task_ins.task.consumer.node_id == 0 + and task_ins.task.delivered_at == "" + ): + task_ins_list.append(task_ins) + if limit and len(task_ins_list) == limit: + break # Mark all of them as delivered delivered_at = now().isoformat() @@ -164,7 +168,8 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: task_res_to_be_deleted.add(task_res_id) for task_id in task_ins_to_be_deleted: - del self.task_ins_store[task_id] + with self.lock: + del self.task_ins_store[task_id] for task_id in task_res_to_be_deleted: del self.task_res_store[task_id] From 39ef78bf6e742ea7f7a9098ef7e265cb8d60372d Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 26 Feb 2024 22:28:40 +0100 Subject: [PATCH 05/17] Fix incorrect link in FedPara README (#3014) --- baselines/fedpara/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baselines/fedpara/README.md b/baselines/fedpara/README.md index 0ff9203a9a58..89cca76a2aa2 100644 --- a/baselines/fedpara/README.md +++ b/baselines/fedpara/README.md @@ -93,7 +93,7 @@ As for the parameters ratio ($\gamma$) we use the following model sizes. As in t ## Environment Setup To construct the Python environment follow these steps: -It is assumed that `pyenv` is installed, `poetry` is installed and python 3.10.6 is installed using `pyenv`. Refer to this [documentation](https://flower.ai/docs/baselines/how-to-usef-baselines.html#setting-up-your-machine) to ensure that your machine is ready. +It is assumed that `pyenv` is installed, `poetry` is installed and python 3.10.6 is installed using `pyenv`. 
Refer to this [documentation](https://flower.ai/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) to ensure that your machine is ready. ```bash # Set Python 3.10 From ccb0b35ce0477a9048162a1a43069eabc951c66b Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 26 Feb 2024 22:46:20 +0100 Subject: [PATCH 06/17] Add `partition_id` to `Metadata` (#3013) --- .../client/message_handler/message_handler.py | 2 +- .../message_handler/message_handler_test.py | 2 ++ src/py/flwr/common/message.py | 19 ++++++++++++++++++- .../ray_transport/ray_client_proxy.py | 1 + .../ray_transport/ray_client_proxy_test.py | 3 ++- 5 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e7e6c7e05c71..87cace88ec27 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -98,7 +98,7 @@ def handle_legacy_message_from_msgtype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most mod.""" - client = client_fn(str(message.metadata.dst_node_id)) + client = client_fn(str(message.metadata.partition_id)) # Check if NumPyClient is returend if isinstance(client, NumPyClient): diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 9fc126f27923..c24b51972f30 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -269,6 +269,8 @@ def test_invalid_message_run_id(self) -> None: invalid_metadata_list: List[Metadata] = [] attrs = list(vars(self.valid_out_metadata).keys()) for attr in attrs: + if attr == "_partition_id": + continue if attr == "_ttl": # Skip configurable ttl continue # Make an invalid metadata diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py 
index 14dae0f6ee57..1e1132e42e27 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -14,7 +14,6 @@ # ============================================================================== """Message.""" - from __future__ import annotations from dataclasses import dataclass @@ -46,6 +45,10 @@ class Metadata: # pylint: disable=too-many-instance-attributes message_type : str A string that encodes the action to be executed on the receiving end. + partition_id : Optional[int] + An identifier that can be used when loading a particular + data partition for a ClientApp. Making use of this identifier + is more relevant when conducting simulations. """ _run_id: int @@ -56,6 +59,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes _group_id: str _ttl: str _message_type: str + _partition_id: int | None def __init__( # pylint: disable=too-many-arguments self, @@ -67,6 +71,7 @@ def __init__( # pylint: disable=too-many-arguments group_id: str, ttl: str, message_type: str, + partition_id: int | None = None, ) -> None: self._run_id = run_id self._message_id = message_id @@ -76,6 +81,7 @@ def __init__( # pylint: disable=too-many-arguments self._group_id = group_id self._ttl = ttl self._message_type = message_type + self._partition_id = partition_id @property def run_id(self) -> int: @@ -137,6 +143,16 @@ def message_type(self, value: str) -> None: """Set message_type.""" self._message_type = value + @property + def partition_id(self) -> int | None: + """An identifier telling which data partition a ClientApp should use.""" + return self._partition_id + + @partition_id.setter + def partition_id(self, value: int) -> None: + """Set patition_id.""" + self._partition_id = value + @dataclass class Message: @@ -202,6 +218,7 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message: group_id=self.metadata.group_id, ttl=ttl, message_type=self.metadata.message_type, + partition_id=self.metadata.partition_id, ), content=content, ) diff --git 
a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 405e0920c5a4..a45321ed2368 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -111,6 +111,7 @@ def _wrap_recordset_in_message( reply_to_message="", ttl=str(timeout) if timeout else "", message_type=message_type, + partition_id=int(self.cid), ), ) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 3eeabe0292c9..24fe3546e7d9 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -198,10 +198,11 @@ def _load_app() -> ClientApp: message_id="", group_id="", src_node_id=0, - dst_node_id=int(cid), + dst_node_id=12345, reply_to_message="", ttl="", message_type=MESSAGE_TYPE_GET_PROPERTIES, + partition_id=int(cid), ), ) pool.submit_client_job( From 7bfc58ad98eef114f6c876e6a364082e5507f38f Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 26 Feb 2024 22:57:29 +0100 Subject: [PATCH 07/17] Fix incorrect URLs for baseline doc page (#3015) --- doc/locales/fr/LC_MESSAGES/framework-docs.po | 6 +++--- doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po | 2 +- doc/source/ref-changelog.md | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/locales/fr/LC_MESSAGES/framework-docs.po b/doc/locales/fr/LC_MESSAGES/framework-docs.po index 920a47abab3b..ba5ea5ec070d 100644 --- a/doc/locales/fr/LC_MESSAGES/framework-docs.po +++ b/doc/locales/fr/LC_MESSAGES/framework-docs.po @@ -1325,7 +1325,7 @@ msgid "" msgstr "" "Si tu n'es pas familier avec les Flower Baselines, tu devrais " "probablement consulter notre `guide de contribution pour les baselines " -"`_." +"`_." 
#: ../../source/contributor-ref-good-first-contributions.rst:27 msgid "" @@ -15862,7 +15862,7 @@ msgstr "" "l'utilisation de [Flower Baselines](https://flower.ai/docs/using-" "baselines.html). Avec cette première version préliminaire, nous invitons " "également la communauté à [contribuer à leurs propres lignes de " -"base](https://flower.ai/docs/contributing-baselines.html)." +"base](https://flower.ai/docs/baselines/how-to-contribute-baselines.html)." #: ../../source/ref-changelog.md:662 msgid "" @@ -25474,7 +25474,7 @@ msgstr "" #~ " papers. If you want to add a" #~ " new baseline or experiment, please " #~ "check the `Contributing Baselines " -#~ "`_ " +#~ "`_ " #~ "section." #~ msgstr "" diff --git a/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po b/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po index ab1c8dc39e64..b6c32f994597 100644 --- a/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po +++ b/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po @@ -1230,7 +1230,7 @@ msgid "" "/contributing-baselines.html>`_." msgstr "" "如果您对 Flower Baselines 还不熟悉,也许可以看看我们的 `Baselines贡献指南 " -"`_。" +"`_。" #: ../../source/contributor-ref-good-first-contributions.rst:27 msgid "" diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 41dc91873c6c..54092e15a564 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -657,7 +657,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - **Flower Baselines (preview): FedOpt, FedBN, FedAvgM** ([#919](https://github.com/adap/flower/pull/919), [#1127](https://github.com/adap/flower/pull/1127), [#914](https://github.com/adap/flower/pull/914)) - The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.ai/docs/using-baselines.html). 
With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.ai/docs/contributing-baselines.html). + The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.ai/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.ai/docs/baselines/how-to-contribute-baselines.html). - **C++ client SDK (preview) and code example** ([#1111](https://github.com/adap/flower/pull/1111)) From 4abfd066b444e6dbc01b2944ba5934dfa3e39b03 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:27:19 +0100 Subject: [PATCH 08/17] Add DirichletPartitioner (#2795) Co-authored-by: Javier --- .../flwr_datasets/partitioner/__init__.py | 2 + .../partitioner/dirichlet_partitioner.py | 323 ++++++++++++++++++ .../partitioner/dirichlet_partitioner_test.py | 170 +++++++++ 3 files changed, 495 insertions(+) create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner.py create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 5e7c86718f67..6a85f8a11749 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -15,6 +15,7 @@ """Flower Datasets Partitioner package.""" +from .dirichlet_partitioner import DirichletPartitioner from .exponential_partitioner import ExponentialPartitioner from .iid_partitioner import IidPartitioner from .linear_partitioner import LinearPartitioner @@ -27,6 +28,7 @@ "IidPartitioner", "Partitioner", "NaturalIdPartitioner", + "DirichletPartitioner", "SizePartitioner", 
"LinearPartitioner", "SquarePartitioner", diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py new file mode 100644 index 000000000000..5f1df71991bb --- /dev/null +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -0,0 +1,323 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dirichlet partitioner class that works with Hugging Face Datasets.""" + + +import warnings +from typing import Dict, List, Optional, Union + +import numpy as np + +import datasets +from flwr_datasets.common.typing import NDArrayFloat +from flwr_datasets.partitioner.partitioner import Partitioner + + +# pylint: disable=R0902, R0912 +class DirichletPartitioner(Partitioner): + """Partitioner based on Dirichlet distribution. + + Implementation based on Bayesian Nonparametric Federated Learning of Neural Networks + https://arxiv.org/abs/1905.12022. + + The algorithm sequentially divides the data with each label. The fractions of the + data with each label is drawn from Dirichlet distribution and adjusted in case of + balancing. The data is assigned. In case the `min_partition_size` is not satisfied + the algorithm is run again (the fractions will change since it is a random process + even though the alpha stays the same). 
+ + The notion of balancing is explicitly introduced here (not mentioned in paper but + implemented in the code). It is a mechanism that excludes the node from + assigning new samples to it if the current number of samples on that node exceeds + the average number that the node would get in case of even data distribution. + It is controlled by`self_balancing` parameter. + + Parameters + ---------- + num_partitions : int + The total number of partitions that the data will be divided into. + partition_by : str + Column name of the labels (targets) based on which Dirichlet sampling works. + alpha : Union[int, float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution + min_partition_size : int + The minimum number of samples that each partitions will have (the sampling + process is repeated if any partition is too small). + self_balancing : bool + Whether assign further samples to a partition after the number of samples + exceeded the average number of samples per partition. (True in the original + paper's code although not mentioned in paper itself). + shuffle: bool + Whether to randomize the order of samples. Shuffling applied after the + samples assignment to nodes. + seed: int + Seed used for dataset shuffling. It has no effect if `shuffle` is False. 
+ + Examples + -------- + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import DirichletPartitioner + >>> + >>> partitioner = DirichletPartitioner(num_partitions=10, partition_by="label", + >>> alpha=0.5, min_partition_size=10, + >>> self_balancing=True) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + >>> print(partition[0]) # Print the first example + {'image': , + 'label': 4} + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)] + >>> print(sorted(partition_sizes)) + [2134, 2615, 3646, 6011, 6170, 6386, 6715, 7653, 8435, 10235] + """ + + def __init__( # pylint: disable=R0913 + self, + num_partitions: int, + partition_by: str, + alpha: Union[int, float, List[float], NDArrayFloat], + min_partition_size: int = 10, + self_balancing: bool = True, + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + # Attributes based on the constructor + self._num_partitions = num_partitions + self._check_num_partitions_greater_than_zero() + self._alpha: NDArrayFloat = self._initialize_alpha(alpha) + self._partition_by = partition_by + self._min_partition_size: int = min_partition_size + self._self_balancing = self_balancing + self._shuffle = shuffle + self._seed = seed + self._rng = np.random.default_rng(seed=self._seed) # NumPy random generator + + # Utility attributes + # The attributes below are determined during the first call to load_partition + self._avg_num_of_samples_per_node: Optional[float] = None + self._unique_classes: Optional[Union[List[int], List[str]]] = None + self._node_id_to_indices: Dict[int, List[int]] = {} + self._node_id_to_indices_determined = False + + def load_partition(self, node_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. 
+ + Parameters + ---------- + node_id : int + the index that corresponds to the requested partition + + Returns + ------- + dataset_partition : Dataset + single partition of a dataset + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. + self._check_num_partitions_correctness_if_needed() + self._determine_node_id_to_indices_if_needed() + return self.dataset.select(self._node_id_to_indices[node_id]) + + def _initialize_alpha( + self, alpha: Union[int, float, List[float], NDArrayFloat] + ) -> NDArrayFloat: + """Convert alpha to the used format in the code a NDArrayFloat. + + The alpha can be provided in constructor can be in different format for user + convenience. The format into which it's transformed here is used throughout the + code for computation. + + Parameters + ---------- + alpha : Union[int, float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution + + Returns + ------- + alpha : NDArrayFloat + Concentration parameter in a format ready to used in computation. + """ + if isinstance(alpha, int): + alpha = np.array([float(alpha)], dtype=float).repeat(self._num_partitions) + elif isinstance(alpha, float): + alpha = np.array([alpha], dtype=float).repeat(self._num_partitions) + elif isinstance(alpha, List): + if len(alpha) != self._num_partitions: + raise ValueError( + "If passing alpha as a List, it needs to be of length of equal to " + "num_partitions." + ) + alpha = np.asarray(alpha) + elif isinstance(alpha, np.ndarray): + # pylint: disable=R1720 + if alpha.ndim == 1 and alpha.shape[0] != self._num_partitions: + raise ValueError( + "If passing alpha as an NDArray, its length needs to be of length " + "equal to num_partitions." 
+ ) + elif alpha.ndim == 2: + alpha = alpha.flatten() + if alpha.shape[0] != self._num_partitions: + raise ValueError( + "If passing alpha as an NDArray, its size needs to be of length" + " equal to num_partitions." + ) + else: + raise ValueError("The given alpha format is not supported.") + if not (alpha > 0).all(): + raise ValueError( + f"Alpha values should be strictly greater than zero. " + f"Instead it'd be converted to {alpha}" + ) + return alpha + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Create an assignment of indices to the partition indices.""" + if self._node_id_to_indices_determined: + return + + # Generate information needed for Dirichlet partitioning + self._unique_classes = self.dataset.unique(self._partition_by) + assert self._unique_classes is not None + # This is needed only if self._self_balancing is True (the default option) + self._avg_num_of_samples_per_node = self.dataset.num_rows / self._num_partitions + + # Change targets list data type to numpy + targets = np.array(self.dataset[self._partition_by]) + + # Repeat the sampling procedure based on the Dirichlet distribution until the + # min_partition_size is reached. 
+ sampling_try = 0 + while True: + # Prepare data structure to store indices assigned to node ids + node_id_to_indices: Dict[int, List[int]] = {} + for nid in range(self._num_partitions): + node_id_to_indices[nid] = [] + + # Iterated over all unique labels (they are not necessarily of type int) + for k in self._unique_classes: + # Access all the indices associated with class k + indices_representing_class_k = np.nonzero(targets == k)[0] + # Determine division (the fractions) of the data representing class k + # among the partitions + class_k_division_proportions = self._rng.dirichlet(self._alpha) + nid_to_proportion_of_k_samples = {} + for nid in range(self._num_partitions): + nid_to_proportion_of_k_samples[nid] = class_k_division_proportions[ + nid + ] + # Balancing (not mentioned in the paper but implemented) + # Do not assign additional samples to the node if it already has more + # than the average numbers of samples per partition. Note that it might + # especially affect classes that are later in the order. This is the + # reason for more sparse division that the alpha might suggest. 
+ if self._self_balancing: + assert self._avg_num_of_samples_per_node is not None + for nid in nid_to_proportion_of_k_samples.copy(): + if ( + len(node_id_to_indices[nid]) + > self._avg_num_of_samples_per_node + ): + nid_to_proportion_of_k_samples[nid] = 0 + + # Normalize the proportions such that they sum up to 1 + sum_proportions = sum(nid_to_proportion_of_k_samples.values()) + for nid, prop in nid_to_proportion_of_k_samples.copy().items(): + nid_to_proportion_of_k_samples[nid] = prop / sum_proportions + + # Determine the split indices + cumsum_division_fractions = np.cumsum( + list(nid_to_proportion_of_k_samples.values()) + ) + cumsum_division_numbers = cumsum_division_fractions * len( + indices_representing_class_k + ) + # [:-1] is because the np.split requires the division indices but the + # last element represents the sum = total number of samples + indices_on_which_split = cumsum_division_numbers.astype(int)[:-1] + + split_indices = np.split( + indices_representing_class_k, indices_on_which_split + ) + + # Append new indices (coming from class k) to the existing indices + for nid, indices in node_id_to_indices.items(): + indices.extend(split_indices[nid].tolist()) + + # Determine if the indices assignment meets the min_partition_size + # If it does not mean the requirement repeat the Dirichlet sampling process + # Otherwise break the while loop + min_sample_size_on_client = min( + len(indices) for indices in node_id_to_indices.values() + ) + if min_sample_size_on_client >= self._min_partition_size: + break + sample_sizes = [len(indices) for indices in node_id_to_indices.values()] + alpha_not_met = [ + self._alpha[i] + for i, ss in enumerate(sample_sizes) + if ss == min(sample_sizes) + ] + mssg_list_alphas = ( + ( + "Generating partitions by sampling from a list of very wide range " + "of alpha values can be hard to achieve. 
Try reducing the range " + f"between maximum ({max(self._alpha)}) and minimum alpha " + f"({min(self._alpha)}) values or increasing all the values." + ) + if len(self._alpha.flatten().tolist()) > 0 + else "" + ) + warnings.warn( + f"The specified min_partition_size ({self._min_partition_size}) was " + f"not satisfied for alpha ({alpha_not_met}) after " + f"{sampling_try} attempts at sampling from the Dirichlet " + f"distribution. The probability sampling from the Dirichlet " + f"distribution will be repeated. Note: This is not a desired " + f"behavior. It is recommended to adjust the alpha or " + f"min_partition_size instead. {mssg_list_alphas}", + stacklevel=1, + ) + if sampling_try == 10: + raise ValueError( + "The max number of attempts (10) was reached. " + "Please update the values of alpha and try again." + ) + sampling_try += 1 + + # Shuffle the indices not to have the datasets with targets in sequences like + # [00000, 11111, ...]) if the shuffle is True + if self._shuffle: + for indices in node_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + self._node_id_to_indices = node_id_to_indices + self._node_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._node_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." 
+ ) + + def _check_num_partitions_greater_than_zero(self) -> None: + """Test num_partition left sides correctness.""" + if not self._num_partitions > 0: + raise ValueError("The number of partitions needs to be greater than zero.") diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py new file mode 100644 index 000000000000..c123f84effb7 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py @@ -0,0 +1,170 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test DirichletPartitioner.""" + + +# pylint: disable=W0212 +import unittest +from typing import Tuple, Union + +import numpy as np +from numpy.typing import NDArray +from parameterized import parameterized + +from datasets import Dataset +from flwr_datasets.partitioner.dirichlet_partitioner import DirichletPartitioner + + +def _dummy_setup( + num_partitions: int, + alpha: Union[float, NDArray[np.float_]], + num_rows: int, + partition_by: str, + self_balancing: bool = True, +) -> Tuple[Dataset, DirichletPartitioner]: + """Create a dummy dataset and partitioner for testing.""" + data = { + partition_by: [i % 3 for i in range(num_rows)], + "features": list(range(num_rows)), + } + dataset = Dataset.from_dict(data) + partitioner = DirichletPartitioner( + num_partitions=num_partitions, + alpha=alpha, + partition_by=partition_by, + self_balancing=self_balancing, + ) + partitioner.dataset = dataset + return dataset, partitioner + + +class TestDirichletPartitionerSuccess(unittest.TestCase): + """Test DirichletPartitioner used with no exceptions.""" + + @parameterized.expand( # type: ignore + [ + # num_partitions, alpha, num_rows, partition_by + (3, 0.5, 100, "labels"), + (5, 1.0, 150, "labels"), + ] + ) + def test_valid_initialization( + self, num_partitions: int, alpha: float, num_rows: int, partition_by: str + ) -> None: + """Test if alpha is correct scaled based on the given num_partitions.""" + _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + self.assertEqual( + ( + partitioner._num_partitions, + len(partitioner._alpha), + partitioner._partition_by, + ), + (num_partitions, num_partitions, partition_by), + ) + + def test_min_partition_size_requirement(self) -> None: + """Test if partitions are created with min partition size required.""" + _, partitioner = _dummy_setup(3, 0.5, 100, "labels") + partition_list = [partitioner.load_partition(node_id) for 
node_id in [0, 1, 2]] + self.assertTrue( + all(len(p) > partitioner._min_partition_size for p in partition_list) + ) + + def test_alpha_in_ndarray_initialization(self) -> None: + """Test alpha does not change when in NDArrayFloat format.""" + _, partitioner = _dummy_setup(3, np.array([1.0, 1.0, 1.0]), 100, "labels") + self.assertTrue(np.all(partitioner._alpha == np.array([1.0, 1.0, 1.0]))) + + def test__determine_node_id_to_indices(self) -> None: + """Test the determine_nod_id_to_indices matches the flag after the call.""" + num_partitions, alpha, num_rows, partition_by = 3, 0.5, 100, "labels" + _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + partitioner._determine_node_id_to_indices_if_needed() + self.assertTrue( + partitioner._node_id_to_indices_determined + and len(partitioner._node_id_to_indices) == num_partitions + ) + + +class TestDirichletPartitionerFailure(unittest.TestCase): + """Test DirichletPartitioner failures (exceptions) by incorrect usage.""" + + @parameterized.expand([(-2,), (-1,), (3,), (4,), (100,)]) # type: ignore + def test_load_invalid_partition_index(self, partition_id): + """Test if raises when the load_partition is above the num_partitions.""" + _, partitioner = _dummy_setup(3, 0.5, 100, "labels") + with self.assertRaises(KeyError): + partitioner.load_partition(partition_id) + + @parameterized.expand( # type: ignore + [ + # alpha, num_partitions + (-0.5, 1), + (-0.5, 2), + (-0.5, 3), + (-0.5, 10), + ([0.5, 0.5, -0.5], 3), + ([-0.5, 0.5, -0.5], 3), + ([-0.5, 0.5, 0.5], 3), + ([-0.5, -0.5, -0.5], 3), + ([0.5, 0.5, -0.5, -0.5, 0.5], 5), + (np.array([0.5, 0.5, -0.5]), 3), + (np.array([-0.5, 0.5, -0.5]), 3), + (np.array([-0.5, 0.5, 0.5]), 3), + (np.array([-0.5, -0.5, -0.5]), 3), + (np.array([0.5, 0.5, -0.5, -0.5, 0.5]), 5), + ] + ) + def test_negative_values_in_alpha(self, alpha, num_partitions): + """Test if giving the negative value of alpha raises error.""" + num_rows, partition_by = 100, "labels" + with 
self.assertRaises(ValueError): + _, _ = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + + @parameterized.expand( # type: ignore + [ + # alpha, num_partitions + # alpha greater than the num_partitions + ([0.5, 0.5], 1), + ([0.5, 0.5, 0.5], 2), + (np.array([0.5, 0.5]), 1), + (np.array([0.5, 0.5, 0.5]), 2), + (np.array([0.5, 0.5, 0.5, 0.5]), 3), + ] + ) + def test_incorrect_alpha_shape(self, alpha, num_partitions): + """Test alpha list len not matching the num_partitions.""" + with self.assertRaises(ValueError): + DirichletPartitioner( + num_partitions=num_partitions, alpha=alpha, partition_by="labels" + ) + + @parameterized.expand( # type: ignore + [(0,), (-1,), (11,), (100,)] + ) # num_partitions, + def test_invalid_num_partitions(self, num_partitions): + """Test if 0 is invalid num_partitions.""" + with self.assertRaises(ValueError): + _, partitioner = _dummy_setup( + num_partitions=num_partitions, + alpha=1.0, + num_rows=10, + partition_by="labels", + ) + partitioner.load_partition(0) + + +if __name__ == "__main__": + unittest.main() From 65f77a98bfb8fe418ddf96641f4cdbf1ba3a07f0 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:44:22 +0100 Subject: [PATCH 09/17] Add ShardPartitioner (#2792) Co-authored-by: Javier --- .../flwr_datasets/partitioner/__init__.py | 2 + .../partitioner/shard_partitioner.py | 354 ++++++++++++++++ .../partitioner/shard_partitioner_test.py | 392 ++++++++++++++++++ 3 files changed, 748 insertions(+) create mode 100644 datasets/flwr_datasets/partitioner/shard_partitioner.py create mode 100644 datasets/flwr_datasets/partitioner/shard_partitioner_test.py diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 6a85f8a11749..73d048ddf3ff 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -21,6 +21,7 @@ from .linear_partitioner import 
LinearPartitioner from .natural_id_partitioner import NaturalIdPartitioner from .partitioner import Partitioner +from .shard_partitioner import ShardPartitioner from .size_partitioner import SizePartitioner from .square_partitioner import SquarePartitioner @@ -32,5 +33,6 @@ "SizePartitioner", "LinearPartitioner", "SquarePartitioner", + "ShardPartitioner", "ExponentialPartitioner", ] diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py new file mode 100644 index 000000000000..7c86570fe487 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -0,0 +1,354 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Shard partitioner class.""" + + +# pylint: disable=R0912 +import math +from typing import Dict, List, Optional + +import numpy as np + +import datasets +from flwr_datasets.partitioner.partitioner import Partitioner + + +class ShardPartitioner(Partitioner): # pylint: disable=R0902 + """Partitioner based on shard of (typically) unique classes. + + The algorithm works as follows: the dataset is sorted by label e.g. 
[samples with
+ label 1, samples with label 2 ...], then the shards are created, with each
+ shard of size = `shard_size` if provided or automatically calculated:
+ shard_size = len(dataset) / (`num_partitions` * `num_shards_per_node`).
+
+ A shard is just a block (chunk) of a `dataset` that contains `shard_size`
+ consecutive samples. There might be shards that contain samples associated with more
+ than a single unique label. The first case is (remember the preprocessing step sorts
+ the dataset by label) when a shard is constructed from samples at the boundaries of
+ the sorted dataset and therefore belonging to different classes e.g. the "leftover"
+ of samples of class 1 and the majority of class 2. Another scenario when a shard
+ has samples with more than one unique label is when the shard size is bigger than
+ the number of samples of a certain class.
+
+ Each partition is created from `num_shards_per_node` shards that are chosen randomly.
+
+ There are a few ways of partitioning data that result in certain properties
+ (depending on the parameters specification):
+ 1) same number of shards per node + the same shard size (specify:
+ a) `num_shards_per_node`, `shard_size`; or b) `num_shards_per_node`)
+ In case of b the `shard_size` is calculated as floor(len(dataset) /
+ (`num_shards_per_node` * `num_partitions`))
+ 2) possibly different number of shards per node (use nearly all data) + the same
+ shard size (specify: `shard_size` + `keep_incomplete_shard=False`)
+ 3) possibly different number of shards per node (use all data) + possibly different
+ shard size (specify: `shard_size` + `keep_incomplete_shard=True`)
+
+
+ Algorithm based on the description in Communication-Efficient Learning of Deep
+ Networks from Decentralized Data https://arxiv.org/abs/1602.05629. This
+ implementation expands on the initial idea by enabling more hyperparameter
+ specification, therefore providing more control on how partitions are created. 
+ It enables the division obtained in the original paper.
+
+ Parameters
+ ----------
+ num_partitions : int
+ The total number of partitions that the data will be divided into.
+ partition_by : str
+ Column name of the labels (targets) based on which the sharding works.
+ num_shards_per_node : Optional[int]
+ Number of shards to assign to a single partition (node). It's an alternative
+ to `shard_size`.
+ shard_size : Optional[int]
+ Size of a single shard (a partition has one or more shards). If the size is not
+ given it will be automatically computed.
+ keep_incomplete_shard : bool
+ Whether to drop the last shard which might be incomplete (smaller than the
+ others). If it is dropped each shard is of equal size. (It does not mean that each
+ client gets an equal number of shards, which only happens if
+ `num_partitions` % `num_shards` = 0). This parameter has no effect if
+ `num_shards_per_node` and `shard_size` are specified.
+ shuffle: bool
+ Whether to randomize the order of samples. Shuffling is applied after the
+ sample assignment to nodes.
+ seed: int
+ Seed used for dataset shuffling. It has no effect if `shuffle` is False. 
+ + Examples + -------- + 1) If you need same number of shards per nodes + the same shard size (and you know + both of these values) + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import ShardPartitioner + >>> + >>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label", + >>> num_shards_per_node=2, shard_size=1_000) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + >>> print(partition[0]) # Print the first example + {'image': , + 'label': 3} + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)] + >>> print(partition_sizes) + [2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000] + + 2) If you want to use nearly all the data and do not need to have the number of + shard per each node to be the same + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import ShardPartitioner + >>> + >>> partitioner = ShardPartitioner(num_partitions=9, partition_by="label", + >>> shard_size=1_000) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(9)] + >>> print(partition_sizes) + [7000, 7000, 7000, 7000, 7000, 7000, 6000, 6000, 6000] + + 3) If you want to use all the data + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import ShardPartitioner + >>> + >>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label", + >>> shard_size=990, keep_incomplete_shard=True) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)] + >>> print(sorted(partition_sizes)) + [5550, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 6930] + """ + + def __init__( # pylint: disable=R0913 + self, + num_partitions: int, + partition_by: str, + 
num_shards_per_node: Optional[int] = None, + shard_size: Optional[int] = None, + keep_incomplete_shard: bool = False, + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + # Attributes based on the constructor + _check_if_natual_number(num_partitions, "num_partitions") + self._num_partitions = num_partitions + self._partition_by = partition_by + _check_if_natual_number(num_shards_per_node, "num_shards_per_node", True) + self._num_shards_per_node = num_shards_per_node + self._num_shards_used: Optional[int] = None + _check_if_natual_number(shard_size, "shard_size", True) + self._shard_size = shard_size + self._keep_incomplete_shard = keep_incomplete_shard + self._shuffle = shuffle + self._seed = seed + + # Utility attributes + self._rng = np.random.default_rng(seed=self._seed) # NumPy random generator + self._node_id_to_indices: Dict[int, List[int]] = {} + self._node_id_to_indices_determined = False + + def load_partition(self, node_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + node_id : int + the index that corresponds to the requested partition + + Returns + ------- + dataset_partition : Dataset + single partition of a dataset + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. + self._check_num_partitions_correctness_if_needed() + self._check_possibility_of_partitions_creation() + self._sort_dataset_if_needed() + self._determine_node_id_to_indices_if_needed() + return self.dataset.select(self._node_id_to_indices[node_id]) + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Assign sample indices to each node id. + + This method works on sorted datasets. A "shard" is a part of the dataset of + consecutive samples (if self._keep_incomplete_shard is False, each shard is same + size). 
+ """ + # No need to do anything if that node_id_to_indices are already determined + if self._node_id_to_indices_determined: + return + + # One of the specification allows to skip the `num_shards_per_node` param + if self._num_shards_per_node is not None: + self._num_shards_used = int( + self._num_partitions * self._num_shards_per_node + ) + num_shards_per_node_array = ( + np.ones(self._num_partitions) * self._num_shards_per_node + ) + if self._shard_size is None: + self._compute_shard_size_if_missing() + assert self._shard_size is not None + if self._keep_incomplete_shard: + num_usable_shards_in_dataset = int( + math.ceil(len(self.dataset) / self._shard_size) + ) + else: + num_usable_shards_in_dataset = int( + math.floor(len(self.dataset) / self._shard_size) + ) + else: + num_usable_shards_in_dataset = int( + math.floor(len(self.dataset) / self._shard_size) + ) + elif self._num_shards_per_node is None: + if self._shard_size is None: + raise ValueError( + "The shard_size needs to be specified if the " + "num_shards_per_node is None" + ) + if self._keep_incomplete_shard is False: + self._num_shards_used = int( + math.floor(len(self.dataset) / self._shard_size) + ) + num_usable_shards_in_dataset = self._num_shards_used + elif self._keep_incomplete_shard is True: + self._num_shards_used = int( + math.ceil(len(self.dataset) / self._shard_size) + ) + num_usable_shards_in_dataset = self._num_shards_used + if num_usable_shards_in_dataset < self._num_partitions: + raise ValueError( + "Based on the given arguments the creation of the partitions " + "is impossible. The implied number of partitions that can be " + "used is lower than the number of requested partitions " + "resulting in empty partitions. Please decrease the size of " + "shards: `shard_size`." + ) + else: + raise ValueError( + "The keep_incomplete_shards need to be specified " + "when _num_shards_per_node is None." 
+ ) + num_shards_per_node = int(self._num_shards_used / self._num_partitions) + # Assign the shards per nodes (so far, the same as in ideal case) + num_shards_per_node_array = ( + np.ones(self._num_partitions) * num_shards_per_node + ) + num_shards_assigned = self._num_partitions * num_shards_per_node + num_shards_to_assign = self._num_shards_used - num_shards_assigned + # Assign the "missing" shards + for i in range(num_shards_to_assign): + num_shards_per_node_array[i] += 1 + + else: + raise ValueError( + "The specification of nm_shards_per_node and " + "keep_incomplete_shards is not correct." + ) + + if num_usable_shards_in_dataset < self._num_partitions: + raise ValueError( + "The specified configuration results in empty partitions because the " + "number of usable shards is smaller that the number partitions. " + "Try decreasing the shard size or the number of partitions. " + ) + + indices_on_which_to_split_shards = np.cumsum( + num_shards_per_node_array, dtype=int + ) + + shard_indices_array = self._rng.permutation(num_usable_shards_in_dataset)[ + : self._num_shards_used + ] + # Randomly assign shards to node_id + nid_to_shard_indices = np.split( + shard_indices_array, indices_on_which_to_split_shards + )[:-1] + node_id_to_indices: Dict[int, List[int]] = { + cid: [] for cid in range(self._num_partitions) + } + # Compute node_id to sample indices based on the shard indices + for node_id in range(self._num_partitions): + for shard_idx in nid_to_shard_indices[node_id]: + start_id = int(shard_idx * self._shard_size) + end_id = min(int((shard_idx + 1) * self._shard_size), len(self.dataset)) + node_id_to_indices[node_id].extend(list(range(start_id, end_id))) + if self._shuffle: + for indices in node_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + self._node_id_to_indices = node_id_to_indices + self._node_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the 
dataset is given (in load_partition).""" + if not self._node_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _sort_dataset_if_needed(self) -> None: + """Sort dataset prior to determining the partitions. + + Operation only needed to be performed one time. It's required for the creation + of shards with the same labels. + """ + if self._node_id_to_indices_determined: + return + self._dataset = self.dataset.sort(self._partition_by) + + def _compute_shard_size_if_missing(self) -> None: + """Compute the parameters needed to perform sharding. + + This method should be called after the dataset is assigned. + """ + if self._shard_size is None: + # If shard size is not specified it needs to be computed + num_rows = self.dataset.num_rows + self._shard_size = int(num_rows / self._num_shards_used) + + def _check_possibility_of_partitions_creation(self) -> None: + if self._shard_size is not None and self._num_shards_per_node is not None: + implied_min_dataset_size = ( + self._shard_size * self._num_shards_per_node * self._num_partitions + ) + if implied_min_dataset_size > len(self.dataset): + raise ValueError( + f"Based on the given arguments the creation of the " + "partitions is impossible. The implied minimum dataset" + f"size is {implied_min_dataset_size} but the dataset" + f"size is {len(self.dataset)}" + ) + + +def _check_if_natual_number( + number: Optional[int], parameter_name: str, none_acceptable: bool = False +) -> None: + if none_acceptable and number is None: + return + if not isinstance(number, int): + raise TypeError( + f"The expected type of {parameter_name} is int but given: {number} of type " + f"{type(number)}. Please specify the correct type." 
+ ) + if not number >= 1: + raise ValueError( + f"The expected value of {parameter_name} is >= 1 (greater or equal to 1) " + f"but given: {number} which does not meet this condition. Please " + f"provide a correct number." + ) diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner_test.py b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py new file mode 100644 index 000000000000..47968699bba7 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py @@ -0,0 +1,392 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test ShardPartitioner.""" + + +# pylint: disable=W0212, R0913 +import unittest +from typing import Optional, Tuple + +from datasets import Dataset +from flwr_datasets.partitioner.shard_partitioner import ShardPartitioner + + +def _dummy_setup( + num_rows: int, + partition_by: str, + num_partitions: int, + num_shards_per_node: Optional[int], + shard_size: Optional[int], + keep_incomplete_shard: bool = False, +) -> Tuple[Dataset, ShardPartitioner]: + """Create a dummy dataset for testing.""" + data = { + partition_by: [i % 3 for i in range(num_rows)], + "features": list(range(num_rows)), + } + dataset = Dataset.from_dict(data) + partitioner = ShardPartitioner( + num_partitions=num_partitions, + num_shards_per_node=num_shards_per_node, + partition_by=partition_by, + shard_size=shard_size, + keep_incomplete_shard=keep_incomplete_shard, + ) + partitioner.dataset = dataset + return dataset, partitioner + + +class TestShardPartitionerSpec1(unittest.TestCase): + """Test first possible initialization of ShardPartitioner. + + Specify num_shards_per_node and shard_size arguments. 
+ """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [30, 30, 30]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec2(unittest.TestCase): + """Test second possible initialization of ShardPartitioner. + + Specify shard_size and keep_incomplete_shard=False. 
This setting creates partitions + that might have various sizes (each shard is same size). + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [30, 40, 40]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). 
+ """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec3(unittest.TestCase): + """Test third possible initialization of ShardPartitioner. + + Specify shard_size and keep_incomplete_shard=True. This setting creates partitions + that might have various sizes (each shard is same size). + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [33, 40, 40]) + + def test_unique_samples(self) -> None: + """Test if each 
partition has unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec4(unittest.TestCase): + """Test fourth possible initialization of ShardPartitioner. + + Specify num_shards_per_node but not shard_size arguments. + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [36, 36, 36]) + + def test_unique_samples(self) -> None: + """Test if each partition has 
unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerIncorrectSpec(unittest.TestCase): + """Test the incorrect specification cases. + + The lack of correctness can be caused by the num_partitions, shard_size and + num_shards_per_partition can create. + """ + + def test_incorrect_specification(self) -> None: + """Test if the given specification makes the partitioning possible.""" + partition_by = "label" + num_rows = 10 + num_partitions = 3 + num_shards_per_node = 2 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + with self.assertRaises(ValueError): + _ = partitioner.load_partition(0) + + def test_too_big_shard_size(self) -> None: + """Test if it is impossible to create an empty partition.""" + partition_by = "label" + num_rows = 20 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + with self.assertRaises(ValueError): + _ = partitioner.load_partition(2).num_rows + + +if __name__ == "__main__": + unittest.main() From 5a3679a93137bab6f120c06a44983ae2f5a92ed3 Mon Sep 17 00:00:00 2001 From: mohammadnaseri Date: Tue, 27 Feb 2024 13:08:45 +0000 Subject: [PATCH 
10/17] Introduce central DP (client-side fixed clipping) (#2893) --- src/py/flwr/client/mod/__init__.py | 2 + src/py/flwr/client/mod/centraldp_mods.py | 76 +++++++++ .../common/differential_privacy_constants.py | 2 + src/py/flwr/server/strategy/__init__.py | 6 +- .../flwr/server/strategy/dp_fixed_clipping.py | 158 +++++++++++++++++- 5 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 src/py/flwr/client/mod/centraldp_mods.py diff --git a/src/py/flwr/client/mod/__init__.py b/src/py/flwr/client/mod/__init__.py index a181865614df..8ed6f52ef788 100644 --- a/src/py/flwr/client/mod/__init__.py +++ b/src/py/flwr/client/mod/__init__.py @@ -15,10 +15,12 @@ """Mods.""" +from .centraldp_mods import fixedclipping_mod from .secure_aggregation.secaggplus_mod import secaggplus_mod from .utils import make_ffn __all__ = [ "make_ffn", "secaggplus_mod", + "fixedclipping_mod", ] diff --git a/src/py/flwr/client/mod/centraldp_mods.py b/src/py/flwr/client/mod/centraldp_mods.py new file mode 100644 index 000000000000..76bfe1b06e55 --- /dev/null +++ b/src/py/flwr/client/mod/centraldp_mods.py @@ -0,0 +1,76 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Clipping modifiers for central DP with client-side clipping.""" + + +from flwr.client.typing import ClientAppCallable +from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.constant import MESSAGE_TYPE_FIT +from flwr.common.context import Context +from flwr.common.differential_privacy import compute_clip_model_update +from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM +from flwr.common.message import Message + + +def fixedclipping_mod( + msg: Message, ctxt: Context, call_next: ClientAppCallable +) -> Message: + """Client-side fixed clipping modifier. + + This mod needs to be used with the DifferentialPrivacyClientSideFixedClipping + server-side strategy wrapper. + + The wrapper sends the clipping_norm value to the client. + + This mod clips the client model updates before sending them to the server. + + It operates on messages with type MESSAGE_TYPE_FIT. + + Notes + ----- + Consider the order of mods when using multiple. + + Typically, fixedclipping_mod should be the last to operate on params. + """ + if msg.metadata.message_type != MESSAGE_TYPE_FIT: + return call_next(msg, ctxt) + fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True) + if KEY_CLIPPING_NORM not in fit_ins.config: + raise KeyError( + f"The {KEY_CLIPPING_NORM} value is not supplied by the " + f"DifferentialPrivacyClientSideFixedClipping wrapper at" + f" the server side." 
+ ) + + clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM]) + server_to_client_params = parameters_to_ndarrays(fit_ins.parameters) + + # Call inner app + out_msg = call_next(msg, ctxt) + fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True) + + client_to_server_params = parameters_to_ndarrays(fit_res.parameters) + + # Clip the client update + compute_clip_model_update( + client_to_server_params, + server_to_client_params, + clipping_norm, + ) + + fit_res.parameters = ndarrays_to_parameters(client_to_server_params) + out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True) + return out_msg diff --git a/src/py/flwr/common/differential_privacy_constants.py b/src/py/flwr/common/differential_privacy_constants.py index 9ec080975aa3..61ba3ba9a958 100644 --- a/src/py/flwr/common/differential_privacy_constants.py +++ b/src/py/flwr/common/differential_privacy_constants.py @@ -14,6 +14,8 @@ # ============================================================================== """Constants for differential privacy.""" + +KEY_CLIPPING_NORM = "clipping_norm" CLIENTS_DISCREPANCY_WARNING = ( "The number of clients returning parameters (%s)" " differs from the number of sampled clients (%s)." 
diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index a31f5d48b77c..6b058c3df05b 100644 --- a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -16,7 +16,10 @@ from .bulyan import Bulyan as Bulyan -from .dp_fixed_clipping import DifferentialPrivacyServerSideFixedClipping +from .dp_fixed_clipping import ( + DifferentialPrivacyClientSideFixedClipping, + DifferentialPrivacyServerSideFixedClipping, +) from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed from .fault_tolerant_fedavg import FaultTolerantFedAvg as FaultTolerantFedAvg @@ -59,4 +62,5 @@ "DPFedAvgFixed", "Strategy", "DifferentialPrivacyServerSideFixedClipping", + "DifferentialPrivacyClientSideFixedClipping", ] diff --git a/src/py/flwr/server/strategy/dp_fixed_clipping.py b/src/py/flwr/server/strategy/dp_fixed_clipping.py index d18a6b2079d9..fd1809125d3e 100644 --- a/src/py/flwr/server/strategy/dp_fixed_clipping.py +++ b/src/py/flwr/server/strategy/dp_fixed_clipping.py @@ -36,7 +36,10 @@ add_gaussian_noise_to_params, compute_clip_model_update, ) -from flwr.common.differential_privacy_constants import CLIENTS_DISCREPANCY_WARNING +from flwr.common.differential_privacy_constants import ( + CLIENTS_DISCREPANCY_WARNING, + KEY_CLIPPING_NORM, +) from flwr.common.logger import log from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy @@ -44,7 +47,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy): - """Wrapper for Central DP with Server Side Fixed Clipping. + """Strategy wrapper for central differential privacy with server-side fixed + clipping. 
Parameters ---------- @@ -185,3 +189,153 @@ def evaluate( ) -> Optional[Tuple[float, Dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) + + +class DifferentialPrivacyClientSideFixedClipping(Strategy): + """Strategy wrapper for central differential privacy with client-side fixed + clipping. + + Use `fixedclipping_mod` modifier at the client side. + + In comparison to `DifferentialPrivacyServerSideFixedClipping`, + which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping` + expects clipping to happen on the client-side, usually by using the built-in + `fixedclipping_mod `. + + Parameters + ---------- + strategy : Strategy + The strategy to which DP functionalities will be added by this wrapper. + noise_multiplier : float + The noise multiplier for the Gaussian mechanism for model updates. + A value of 1.0 or higher is recommended for strong privacy. + clipping_norm : float + The value of the clipping norm. + num_sampled_clients : int + The number of clients that are sampled on each round. + + Examples + -------- + Create a strategy: + + >>> strategy = fl.server.strategy.FedAvg(...) 
+ + Wrap the strategy with the `DifferentialPrivacyServerSideFixedClipping` wrapper: + + >>> DifferentialPrivacyClientSideFixedClipping( + >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients + >>> ) + + On the client, add the `fixedclipping_mod` to the client-side mods: + + >>> app = fl.client.ClientApp( + >>> client_fn=FlowerClient().to_client(), mods=[fixedclipping_mod] + >>> ) + """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + strategy: Strategy, + noise_multiplier: float, + clipping_norm: float, + num_sampled_clients: int, + ) -> None: + super().__init__() + + self.strategy = strategy + + if noise_multiplier < 0: + raise ValueError("The noise multiplier should be a non-negative value.") + + if clipping_norm <= 0: + raise ValueError("The clipping threshold should be a positive value.") + + if num_sampled_clients <= 0: + raise ValueError( + "The number of sampled clients should be a positive value." + ) + + self.noise_multiplier = noise_multiplier + self.clipping_norm = clipping_norm + self.num_sampled_clients = num_sampled_clients + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters using given strategy.""" + return self.strategy.initialize_parameters(client_manager) + + def configure_fit( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} + inner_strategy_config_result = self.strategy.configure_fit( + server_round, parameters, client_manager + ) + for _, fit_ins in inner_strategy_config_result: + fit_ins.config.update(additional_config) + + return 
inner_strategy_config_result + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + return self.strategy.configure_evaluate( + server_round, parameters, client_manager + ) + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Add noise to the aggregated parameters.""" + if failures: + return None, {} + + if len(results) != self.num_sampled_clients: + log( + WARNING, + CLIENTS_DISCREPANCY_WARNING, + len(results), + self.num_sampled_clients, + ) + + # Pass the new parameters for aggregation + aggregated_params, metrics = self.strategy.aggregate_fit( + server_round, results, failures + ) + + # Add Gaussian noise to the aggregated parameters + if aggregated_params: + aggregated_params = add_gaussian_noise_to_params( + aggregated_params, + self.noise_multiplier, + self.clipping_norm, + self.num_sampled_clients, + ) + return aggregated_params, metrics + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using the given strategy.""" + return self.strategy.aggregate_evaluate(server_round, results, failures) + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function from the strategy.""" + return self.strategy.evaluate(server_round, parameters) From 19d0b7097ffd566d0d7ea818e0e4e7faba130e57 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 27 Feb 2024 16:16:45 +0100 Subject: [PATCH 11/17] fix (#3022) Co-authored-by: Taner Topal --- 
src/py/flwr/cli/new/new_test.py | 34 +++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py index 39717bc67ab3..7a4832013b0c 100644 --- a/src/py/flwr/cli/new/new_test.py +++ b/src/py/flwr/cli/new/new_test.py @@ -77,17 +77,23 @@ def test_new(tmp_path: str) -> None: "client.py", } - # Change into the temprorary directory - os.chdir(tmp_path) - - # Execute - new(project_name=project_name, framework=framework) - - # Assert - file_list = os.listdir(os.path.join(tmp_path, project_name.lower())) - assert set(file_list) == expected_files_top_level - - file_list = os.listdir( - os.path.join(tmp_path, project_name.lower(), project_name.lower()) - ) - assert set(file_list) == expected_files_module + # Current directory + origin = os.getcwd() + + try: + # Change into the temprorary directory + os.chdir(tmp_path) + + # Execute + new(project_name=project_name, framework=framework) + + # Assert + file_list = os.listdir(os.path.join(tmp_path, project_name.lower())) + assert set(file_list) == expected_files_top_level + + file_list = os.listdir( + os.path.join(tmp_path, project_name.lower(), project_name.lower()) + ) + assert set(file_list) == expected_files_module + finally: + os.chdir(origin) From 06a7c57b0b90df0ab49565a5dbc695f654a59614 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Tue, 27 Feb 2024 16:28:55 +0100 Subject: [PATCH 12/17] Add new common function for registering exit handlers (#3021) --- src/py/flwr/client/app.py | 3 +- src/py/flwr/common/exit_handlers.py | 87 +++++++++++++++++++++++++++++ src/py/flwr/server/app.py | 61 +++----------------- 3 files changed, 96 insertions(+), 55 deletions(-) create mode 100644 src/py/flwr/common/exit_handlers.py diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 6c3c3a99680d..2d49b3a44f08 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -34,6 +34,7 @@ TRANSPORT_TYPE_REST, 
TRANSPORT_TYPES, ) +from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature from .clientapp import load_client_app @@ -104,7 +105,7 @@ def _load() -> ClientApp: root_certificates=root_certificates, insecure=args.insecure, ) - event(EventType.RUN_CLIENT_APP_LEAVE) + register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) def _parse_args_run_client_app() -> argparse.ArgumentParser: diff --git a/src/py/flwr/common/exit_handlers.py b/src/py/flwr/common/exit_handlers.py new file mode 100644 index 000000000000..30750c28a450 --- /dev/null +++ b/src/py/flwr/common/exit_handlers.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common function to register exit handlers for server and client.""" + + +import sys +from signal import SIGINT, SIGTERM, signal +from threading import Thread +from types import FrameType +from typing import List, Optional + +from grpc import Server + +from flwr.common.telemetry import EventType, event + + +def register_exit_handlers( + event_type: EventType, + grpc_servers: Optional[List[Server]] = None, + bckg_threads: Optional[List[Thread]] = None, +) -> None: + """Register exit handlers for `SIGINT` and `SIGTERM` signals. 
+ + Parameters + ---------- + event_type : EventType + The telemetry event that should be logged before exit. + grpc_servers: Optional[List[Server]] (default: None) + An otpional list of gRPC servers that need to be gracefully + terminated before exiting. + bckg_threads: Optional[List[Thread]] (default: None) + An optional list of threads that need to be gracefully + terminated before exiting. + """ + default_handlers = { + SIGINT: None, + SIGTERM: None, + } + + def graceful_exit_handler( # type: ignore + signalnum, + frame: FrameType, # pylint: disable=unused-argument + ) -> None: + """Exit handler to be registered with `signal.signal`. + + When called will reset signal handler to original signal handler from + default_handlers. + """ + # Reset to default handler + signal(signalnum, default_handlers[signalnum]) + + event_res = event(event_type=event_type) + + if grpc_servers is not None: + for grpc_server in grpc_servers: + grpc_server.stop(grace=1) + + if bckg_threads is not None: + for bckg_thread in bckg_threads: + bckg_thread.join() + + # Ensure event has happend + event_res.result() + + # Setup things for graceful exit + sys.exit(0) + + default_handlers[SIGINT] = signal( # type: ignore + SIGINT, + graceful_exit_handler, # type: ignore + ) + default_handlers[SIGTERM] = signal( # type: ignore + SIGTERM, + graceful_exit_handler, # type: ignore + ) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index ac7a8339b31d..a4913d51315b 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -22,8 +22,6 @@ from logging import ERROR, INFO, WARN from os.path import isfile from pathlib import Path -from signal import SIGINT, SIGTERM, signal -from types import FrameType from typing import List, Optional, Tuple import grpc @@ -36,6 +34,7 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPE_VCE, ) +from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log from flwr.proto.driver_pb2_grpc import ( # pylint: 
disable=E0611 add_DriverServicer_to_server, @@ -212,10 +211,10 @@ def run_driver_api() -> None: ) # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_DRIVER_API_LEAVE, grpc_servers=[grpc_server], bckg_threads=[], - event_type=EventType.RUN_DRIVER_API_LEAVE, ) # Block @@ -280,10 +279,10 @@ def run_fleet_api() -> None: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_FLEET_API_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_FLEET_API_LEAVE, ) # Block @@ -375,10 +374,10 @@ def run_superlink() -> None: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") # Graceful shutdown - _register_exit_handlers( + register_exit_handlers( + event_type=EventType.RUN_SUPERLINK_LEAVE, grpc_servers=grpc_servers, bckg_threads=bckg_threads, - event_type=EventType.RUN_SUPERLINK_LEAVE, ) # Block @@ -413,52 +412,6 @@ def _try_obtain_certificates( return certificates -def _register_exit_handlers( - grpc_servers: List[grpc.Server], - bckg_threads: List[threading.Thread], - event_type: EventType, -) -> None: - default_handlers = { - SIGINT: None, - SIGTERM: None, - } - - def graceful_exit_handler( # type: ignore - signalnum, - frame: FrameType, # pylint: disable=unused-argument - ) -> None: - """Exit handler to be registered with signal.signal. - - When called will reset signal handler to original signal handler from - default_handlers. 
- """ - # Reset to default handler - signal(signalnum, default_handlers[signalnum]) - - event_res = event(event_type=event_type) - - for grpc_server in grpc_servers: - grpc_server.stop(grace=1) - - for bckg_thread in bckg_threads: - bckg_thread.join() - - # Ensure event has happend - event_res.result() - - # Setup things for graceful exit - sys.exit(0) - - default_handlers[SIGINT] = signal( # type: ignore - SIGINT, - graceful_exit_handler, # type: ignore - ) - default_handlers[SIGTERM] = signal( # type: ignore - SIGTERM, - graceful_exit_handler, # type: ignore - ) - - def _run_driver_api_grpc( address: str, state_factory: StateFactory, From 81fadda74398fd7f9999023df32239c8a5ccfc36 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Tue, 27 Feb 2024 20:28:18 +0100 Subject: [PATCH 13/17] Rename `clientapp` to `client_app` (#3020) --- doc/locales/fr/LC_MESSAGES/framework-docs.po | 14 +++++++------- doc/locales/pt_BR/LC_MESSAGES/framework-docs.po | 12 ++++++------ doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po | 12 ++++++------ src/py/flwr/client/__init__.py | 2 +- src/py/flwr/client/app.py | 4 ++-- src/py/flwr/client/{clientapp.py => client_app.py} | 0 .../server/superlink/fleet/vce/backend/backend.py | 2 +- .../superlink/fleet/vce/backend/raybackend.py | 2 +- .../superlink/fleet/vce/backend/raybackend_test.py | 2 +- src/py/flwr/server/superlink/fleet/vce/vce_api.py | 2 +- src/py/flwr/simulation/ray_transport/ray_actor.py | 2 +- .../simulation/ray_transport/ray_client_proxy.py | 2 +- .../ray_transport/ray_client_proxy_test.py | 2 +- 13 files changed, 29 insertions(+), 29 deletions(-) rename src/py/flwr/client/{clientapp.py => client_app.py} (100%) diff --git a/doc/locales/fr/LC_MESSAGES/framework-docs.po b/doc/locales/fr/LC_MESSAGES/framework-docs.po index ba5ea5ec070d..cace04c6980f 100644 --- a/doc/locales/fr/LC_MESSAGES/framework-docs.po +++ b/doc/locales/fr/LC_MESSAGES/framework-docs.po @@ -8316,10 +8316,10 @@ msgid "" msgstr "" #: 
../../source/ref-api/flwr.client.rst:33::1 -#: flwr.client.clientapp.ClientApp:1 of +#: flwr.client.client_app.ClientApp:1 of #, fuzzy msgid "Flower ClientApp." -msgstr "Client de Flower" +msgstr "Flower ClientApp." #: ../../source/ref-api/flwr.client.rst:33::1 msgid ":py:obj:`NumPyClient `\\ \\(\\)" @@ -8623,7 +8623,7 @@ msgstr "" msgid "ClientApp" msgstr "client" -#: flwr.client.clientapp.ClientApp:1 flwr.common.typing.ClientMessage:1 +#: flwr.client.client_app.ClientApp:1 flwr.common.typing.ClientMessage:1 #: flwr.common.typing.DisconnectRes:1 flwr.common.typing.EvaluateIns:1 #: flwr.common.typing.EvaluateRes:1 flwr.common.typing.FitIns:1 #: flwr.common.typing.FitRes:1 flwr.common.typing.GetParametersIns:1 @@ -8638,25 +8638,25 @@ msgid "Bases: :py:class:`object`" msgstr "" #: flwr.client.app.start_client:33 flwr.client.app.start_numpy_client:36 -#: flwr.client.clientapp.ClientApp:4 flwr.server.app.start_server:41 +#: flwr.client.client_app.ClientApp:4 flwr.server.app.start_server:41 #: flwr.server.driver.app.start_driver:30 of #, fuzzy msgid "Examples" msgstr "Exemples de PyTorch" -#: flwr.client.clientapp.ClientApp:5 of +#: flwr.client.client_app.ClientApp:5 of msgid "" "Assuming a typical `Client` implementation named `FlowerClient`, you can " "wrap it in a `ClientApp` as follows:" msgstr "" -#: flwr.client.clientapp.ClientApp:16 of +#: flwr.client.client_app.ClientApp:16 of msgid "" "If the above code is in a Python module called `client`, it can be " "started as follows:" msgstr "" -#: flwr.client.clientapp.ClientApp:21 of +#: flwr.client.client_app.ClientApp:21 of msgid "" "In this `client:app` example, `client` refers to the Python module " "`client.py` in which the previous code lives in and `app` refers to the " diff --git a/doc/locales/pt_BR/LC_MESSAGES/framework-docs.po b/doc/locales/pt_BR/LC_MESSAGES/framework-docs.po index 359458e8db57..df354f7a76cc 100644 --- a/doc/locales/pt_BR/LC_MESSAGES/framework-docs.po +++ 
b/doc/locales/pt_BR/LC_MESSAGES/framework-docs.po @@ -6480,7 +6480,7 @@ msgid "" msgstr "" #: ../../source/ref-api/flwr.client.rst:33::1 -#: flwr.client.clientapp.ClientApp:1 of +#: flwr.client.client_app.ClientApp:1 of msgid "Flower ClientApp." msgstr "" @@ -6779,7 +6779,7 @@ msgstr "" msgid "ClientApp" msgstr "" -#: flwr.client.clientapp.ClientApp:1 flwr.common.typing.ClientMessage:1 +#: flwr.client.client_app.ClientApp:1 flwr.common.typing.ClientMessage:1 #: flwr.common.typing.DisconnectRes:1 flwr.common.typing.EvaluateIns:1 #: flwr.common.typing.EvaluateRes:1 flwr.common.typing.FitIns:1 #: flwr.common.typing.FitRes:1 flwr.common.typing.GetParametersIns:1 @@ -6794,24 +6794,24 @@ msgid "Bases: :py:class:`object`" msgstr "" #: flwr.client.app.start_client:33 flwr.client.app.start_numpy_client:36 -#: flwr.client.clientapp.ClientApp:4 flwr.server.app.start_server:41 +#: flwr.client.client_app.ClientApp:4 flwr.server.app.start_server:41 #: flwr.server.driver.app.start_driver:30 of msgid "Examples" msgstr "" -#: flwr.client.clientapp.ClientApp:5 of +#: flwr.client.client_app.ClientApp:5 of msgid "" "Assuming a typical `Client` implementation named `FlowerClient`, you can " "wrap it in a `ClientApp` as follows:" msgstr "" -#: flwr.client.clientapp.ClientApp:16 of +#: flwr.client.client_app.ClientApp:16 of msgid "" "If the above code is in a Python module called `client`, it can be " "started as follows:" msgstr "" -#: flwr.client.clientapp.ClientApp:21 of +#: flwr.client.client_app.ClientApp:21 of msgid "" "In this `client:app` example, `client` refers to the Python module " "`client.py` in which the previous code lives in and `app` refers to the " diff --git a/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po b/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po index b6c32f994597..a4255e51f4b6 100644 --- a/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po +++ b/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po @@ -7209,7 +7209,7 @@ msgid "" msgstr "" #: 
../../source/ref-api/flwr.client.rst:33::1 -#: flwr.client.clientapp.ClientApp:1 of +#: flwr.client.client_app.ClientApp:1 of #, fuzzy msgid "Flower ClientApp." msgstr "Flower 客户端。" @@ -7511,7 +7511,7 @@ msgstr "当前客户端属性。" msgid "ClientApp" msgstr "客户端" -#: flwr.client.clientapp.ClientApp:1 flwr.common.typing.ClientMessage:1 +#: flwr.client.client_app.ClientApp:1 flwr.common.typing.ClientMessage:1 #: flwr.common.typing.DisconnectRes:1 flwr.common.typing.EvaluateIns:1 #: flwr.common.typing.EvaluateRes:1 flwr.common.typing.FitIns:1 #: flwr.common.typing.FitRes:1 flwr.common.typing.GetParametersIns:1 @@ -7526,24 +7526,24 @@ msgid "Bases: :py:class:`object`" msgstr "" #: flwr.client.app.start_client:33 flwr.client.app.start_numpy_client:36 -#: flwr.client.clientapp.ClientApp:4 flwr.server.app.start_server:41 +#: flwr.client.client_app.ClientApp:4 flwr.server.app.start_server:41 #: flwr.server.driver.app.start_driver:30 of msgid "Examples" msgstr "实例" -#: flwr.client.clientapp.ClientApp:5 of +#: flwr.client.client_app.ClientApp:5 of msgid "" "Assuming a typical `Client` implementation named `FlowerClient`, you can " "wrap it in a `ClientApp` as follows:" msgstr "" -#: flwr.client.clientapp.ClientApp:16 of +#: flwr.client.client_app.ClientApp:16 of msgid "" "If the above code is in a Python module called `client`, it can be " "started as follows:" msgstr "" -#: flwr.client.clientapp.ClientApp:21 of +#: flwr.client.client_app.ClientApp:21 of msgid "" "In this `client:app` example, `client` refers to the Python module " "`client.py` in which the previous code lives in and `app` refers to the " diff --git a/src/py/flwr/client/__init__.py b/src/py/flwr/client/__init__.py index f359fb472cbe..a721fb584164 100644 --- a/src/py/flwr/client/__init__.py +++ b/src/py/flwr/client/__init__.py @@ -19,7 +19,7 @@ from .app import start_client as start_client from .app import start_numpy_client as start_numpy_client from .client import Client as Client -from .clientapp import ClientApp as 
ClientApp +from .client_app import ClientApp as ClientApp from .numpy_client import NumPyClient as NumPyClient from .typing import ClientFn as ClientFn diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 2d49b3a44f08..93d654379cfc 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -23,7 +23,7 @@ from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.client.typing import ClientFn from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address @@ -37,7 +37,7 @@ from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature -from .clientapp import load_client_app +from .client_app import load_client_app from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message diff --git a/src/py/flwr/client/clientapp.py b/src/py/flwr/client/client_app.py similarity index 100% rename from src/py/flwr/client/clientapp.py rename to src/py/flwr/client/client_app.py diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py index f2796a5758a0..1d5e3a6a51ad 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -18,7 +18,7 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, Tuple -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.common.context import Context from flwr.common.message import Message from flwr.common.typing import ConfigsRecordValues diff --git 
a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index b29d76b239e5..602fd94bea51 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -18,7 +18,7 @@ from logging import INFO from typing import Callable, Dict, List, Tuple, Union -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index f0cca527ab96..e75e01ba3af0 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -22,7 +22,7 @@ import ray from flwr.client import Client, NumPyClient -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.common import ( Config, ConfigsRecord, diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 666e7e7d9ec3..5d194632541e 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -19,7 +19,7 @@ from logging import ERROR, INFO from typing import Dict, Optional -from flwr.client.clientapp import ClientApp, load_client_app +from flwr.client.client_app import ClientApp, load_client_app from flwr.client.node_state import NodeState from flwr.common.logger import log from flwr.server.superlink.state import StateFactory diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 5ac0b2c27484..2b8a683cd893 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ 
b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -25,7 +25,7 @@ from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.common import Context, Message from flwr.common.logger import log diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index a45321ed2368..6d04fa875201 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -21,7 +21,7 @@ from flwr import common from flwr.client import ClientFn -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState from flwr.common import Message, Metadata, RecordSet from flwr.common.constant import ( diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 24fe3546e7d9..684fad60bf74 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -22,7 +22,7 @@ import ray from flwr.client import Client, NumPyClient -from flwr.client.clientapp import ClientApp +from flwr.client.client_app import ClientApp from flwr.common import ( Config, ConfigsRecord, From 3b686d5748ae84c5cfd46e3935638cd21c6773d4 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 27 Feb 2024 20:50:15 +0100 Subject: [PATCH 14/17] Handle loading of non-existing `ClientApp` in `RayBackend` (#3018) --- .../superlink/fleet/vce/backend/raybackend.py | 45 +++++-- .../fleet/vce/backend/raybackend_test.py | 125 +++++++++++++----- .../simulation/ray_transport/ray_actor.py | 5 +- 3 files changed, 127 insertions(+), 48 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py 
b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 602fd94bea51..409deb077f1d 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -15,10 +15,12 @@ """Ray backend for the Fleet API using the Simulation Engine.""" import pathlib -from logging import INFO +from logging import ERROR, INFO from typing import Callable, Dict, List, Tuple, Union -from flwr.client.client_app import ClientApp +import ray + +from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message @@ -46,6 +48,9 @@ def __init__( log(INFO, "Initialising: %s", self.__class__.__name__) log(INFO, "Backend config: %s", backend_config) + if not pathlib.Path(work_dir).exists(): + raise ValueError(f"Specified work_dir {work_dir} does not exist.") + # Init ray and append working dir if needed runtime_env = ( self._configure_runtime_env(work_dir=work_dir) if work_dir else None @@ -138,22 +143,34 @@ async def process_message( """ node_id = message.metadata.dst_node_id - # Submite a task to the pool - future = await self.pool.submit( - lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), - (app, message, str(node_id), context), - ) + try: + # Submite a task to the pool + future = await self.pool.submit( + lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), + (app, message, str(node_id), context), + ) - await future + await future - # Fetch result - ( - out_mssg, - updated_context, - ) = await self.pool.fetch_result_and_return_actor_to_pool(future) + # Fetch result + ( + out_mssg, + updated_context, + ) = await self.pool.fetch_result_and_return_actor_to_pool(future) - return out_mssg, updated_context + return out_mssg, updated_context + + except LoadClientAppError as load_ex: + log( + ERROR, + "An exception was raised when processing a message. 
Terminating %s", + self.__class__.__name__, + ) + await self.terminate() + raise load_ex async def terminate(self) -> None: """Terminate all actors in actor pool.""" await self.pool.terminate_all_actors() + ray.shutdown() + log(INFO, "Terminated %s", self.__class__.__name__) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index e75e01ba3af0..fd246b5fc2af 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -16,13 +16,12 @@ import asyncio from math import pi +from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Union from unittest import IsolatedAsyncioTestCase -import ray - from flwr.client import Client, NumPyClient -from flwr.client.client_app import ClientApp +from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app from flwr.common import ( Config, ConfigsRecord, @@ -35,8 +34,7 @@ ) from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES from flwr.common.recordset_compat import getpropertiesins_to_recordset - -from .raybackend import RayBackend +from flwr.server.superlink.fleet.vce.backend.raybackend import RayBackend class DummyClient(NumPyClient): @@ -60,6 +58,19 @@ def _load_app() -> ClientApp: return ClientApp(client_fn=get_dummy_client) +client_app = ClientApp( + client_fn=get_dummy_client, +) + + +def _load_from_module(client_app_module_name: str) -> Callable[[], ClientApp]: + def _load_app() -> ClientApp: + app: ClientApp = load_client_app(client_app_module_name) + return app + + return _load_app + + async def backend_build_process_and_termination( backend: RayBackend, process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None, @@ -73,11 +84,38 @@ async def backend_build_process_and_termination( await backend.terminate() - ray.shutdown() - return to_return +def 
_create_message_and_context() -> Tuple[Message, Context, float]: + + # Construct a Message + mult_factor = 2024 + getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) + recordset = getpropertiesins_to_recordset(getproperties_ins) + message = Message( + content=recordset, + metadata=Metadata( + run_id=0, + message_id="", + group_id="", + src_node_id=0, + dst_node_id=0, + reply_to_message="", + ttl="", + message_type=MESSAGE_TYPE_GET_PROPERTIES, + ), + ) + + # Construct emtpy Context + context = Context(state=RecordSet()) + + # Expected output + expected_output = pi * mult_factor + + return message, context, expected_output + + class AsyncTestRayBackend(IsolatedAsyncioTestCase): """A basic class that allows runnig multliple asyncio tests.""" @@ -88,33 +126,18 @@ def test_backend_creation_and_termination(self) -> None: backend_build_process_and_termination(backend=backend, process_args=None) ) - def test_backend_creation_submit_and_termination(self) -> None: - """Test submit.""" - backend = RayBackend(backend_config={}, work_dir="") + def test_backend_creation_submit_and_termination( + self, + client_app_loader: Callable[[], ClientApp] = _load_app, + workdir: str = "", + ) -> None: + """Test submitting a message to a given ClientApp.""" + backend = RayBackend(backend_config={}, work_dir=workdir) # Define ClientApp - client_app_callable = _load_app - - # Construct a Message - mult_factor = 2024 - getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) - recordset = getpropertiesins_to_recordset(getproperties_ins) - message = Message( - content=recordset, - metadata=Metadata( - run_id=0, - message_id="", - group_id="", - src_node_id=0, - dst_node_id=0, - reply_to_message="", - ttl="", - message_type=MESSAGE_TYPE_GET_PROPERTIES, - ), - ) + client_app_callable = client_app_loader - # Construct emtpy Context - context = Context(state=RecordSet()) + message, context, expected_output = _create_message_and_context() res = asyncio.run( 
backend_build_process_and_termination( @@ -131,11 +154,47 @@ def test_backend_creation_submit_and_termination(self) -> None: content = out_mssg.content assert ( content.configs_records["getpropertiesres.properties"]["result"] - == pi * mult_factor + == expected_output ) # Verify context is correct obtained_result_in_context = updated_context.state.configs_records["result"][ "result" ] - assert obtained_result_in_context == pi * mult_factor + assert obtained_result_in_context == expected_output + + def test_backend_creation_submit_and_termination_non_existing_client_app( + self, + ) -> None: + """Testing with ClientApp module that does not exist.""" + with self.assertRaises(LoadClientAppError): + self.test_backend_creation_submit_and_termination( + client_app_loader=_load_from_module("a_non_existing_module:app") + ) + + def test_backend_creation_submit_and_termination_existing_client_app( + self, + ) -> None: + """Testing with ClientApp module that exist.""" + # Resolve what should be the workdir to pass upon Backend initialisation + file_path = Path(__file__) + working_dir = Path.cwd() + rel_workdir = file_path.relative_to(working_dir) + + # Susbtract last element + rel_workdir_str = str(rel_workdir.parent) + + self.test_backend_creation_submit_and_termination( + client_app_loader=_load_from_module("raybackend_test:client_app"), + workdir=rel_workdir_str, + ) + + def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdir( + self, + ) -> None: + """Testing with ClientApp module that exist but the passed workdir does not.""" + with self.assertRaises(ValueError): + self.test_backend_creation_submit_and_termination( + client_app_loader=_load_from_module("raybackend_test:client_app"), + workdir="/?&%$^#%@$!", + ) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 2b8a683cd893..08d0576e39f0 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ 
b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -25,7 +25,7 @@ from ray import ObjectRef from ray.util.actor_pool import ActorPool -from flwr.client.client_app import ClientApp +from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import Context, Message from flwr.common.logger import log @@ -67,6 +67,9 @@ def run( # Handle task message out_message = app(message=message, context=context) + except LoadClientAppError as load_ex: + raise load_ex + except Exception as ex: client_trace = traceback.format_exc() mssg = ( From 4910520eaa39090476cb38ec503379a3d1559932 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 28 Feb 2024 10:51:18 +0000 Subject: [PATCH 15/17] rename node_id to partition_id (#3023) --- examples/advanced-pytorch/utils.py | 4 ++-- examples/embedded-devices/client_pytorch.py | 4 ++-- examples/embedded-devices/client_tf.py | 4 ++-- .../pytorch-from-centralized-to-federated/cifar.py | 4 ++-- .../client.py | 4 ++-- .../pytorch-from-centralized-to-federated/run.sh | 2 +- examples/quickstart-huggingface/README.md | 4 ++-- examples/quickstart-huggingface/client.py | 14 +++++++------- examples/quickstart-huggingface/run.sh | 2 +- examples/quickstart-mlx/README.md | 8 ++++---- examples/quickstart-mlx/client.py | 4 ++-- examples/quickstart-mlx/run.sh | 2 +- examples/quickstart-pandas/README.md | 4 ++-- examples/quickstart-pandas/client.py | 6 +++--- examples/quickstart-pandas/run.sh | 2 +- examples/quickstart-pytorch-lightning/README.md | 6 +++--- examples/quickstart-pytorch-lightning/client.py | 6 +++--- examples/quickstart-pytorch-lightning/run.sh | 2 +- examples/quickstart-pytorch/README.md | 6 +++--- examples/quickstart-pytorch/client.py | 12 ++++++------ examples/quickstart-pytorch/run.sh | 2 +- examples/quickstart-sklearn-tabular/README.md | 6 +++--- examples/quickstart-sklearn-tabular/client.py | 4 ++-- examples/quickstart-sklearn-tabular/run.sh | 2 +- examples/quickstart-tensorflow/client.py | 4 ++-- 
examples/quickstart-tensorflow/run.sh | 2 +- examples/sklearn-logreg-mnist/README.md | 4 ++-- examples/sklearn-logreg-mnist/client.py | 4 ++-- examples/sklearn-logreg-mnist/run.sh | 2 +- examples/xgboost-comprehensive/README.md | 4 ++-- examples/xgboost-comprehensive/client.py | 4 ++-- examples/xgboost-comprehensive/run_bagging.sh | 2 +- examples/xgboost-comprehensive/run_cyclic.sh | 2 +- examples/xgboost-comprehensive/sim.py | 6 +++--- examples/xgboost-comprehensive/utils.py | 4 ++-- examples/xgboost-quickstart/README.md | 4 ++-- examples/xgboost-quickstart/client.py | 10 +++++----- examples/xgboost-quickstart/run.sh | 2 +- 38 files changed, 84 insertions(+), 84 deletions(-) diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 186f079010dc..4a0f6918cdd6 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -9,10 +9,10 @@ warnings.filterwarnings("ignore") -def load_partition(node_id, toy: bool = False): +def load_partition(partition_id, toy: bool = False): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) partition_train_test = partition_train_test.with_transform(apply_transforms) diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py index f326db7c678c..3f1e6c7d51b7 100644 --- a/examples/embedded-devices/client_pytorch.py +++ b/examples/embedded-devices/client_pytorch.py @@ -105,8 +105,8 @@ def apply_transforms(batch): trainsets = [] validsets = [] - for node_id in range(NUM_CLIENTS): - partition = fds.load_partition(node_id, "train") + for partition_id in range(NUM_CLIENTS): + partition = fds.load_partition(partition_id, "train") # Divide data on each node: 90% train, 10% test partition = 
partition.train_test_split(test_size=0.1) partition = partition.with_transform(apply_transforms) diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py index ae793ecd81e0..d59b31ab1569 100644 --- a/examples/embedded-devices/client_tf.py +++ b/examples/embedded-devices/client_tf.py @@ -40,8 +40,8 @@ def prepare_dataset(use_mnist: bool): fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS}) img_key = "img" partitions = [] - for node_id in range(NUM_CLIENTS): - partition = fds.load_partition(node_id, "train") + for partition_id in range(NUM_CLIENTS): + partition = fds.load_partition(partition_id, "train") partition.set_format("numpy") # Divide data on each node: 90% train, 10% test partition = partition.train_test_split(test_size=0.1) diff --git a/examples/pytorch-from-centralized-to-federated/cifar.py b/examples/pytorch-from-centralized-to-federated/cifar.py index e8f3ec3fd724..277a21da2e70 100644 --- a/examples/pytorch-from-centralized-to-federated/cifar.py +++ b/examples/pytorch-from-centralized-to-federated/cifar.py @@ -51,10 +51,10 @@ def forward(self, x: Tensor) -> Tensor: return x -def load_data(node_id: int): +def load_data(partition_id: int): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index 2a3bbccae2bb..9df4739e0aab 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -81,11 +81,11 @@ def evaluate( def main() -> None: """Load data, start CifarClient.""" parser = 
argparse.ArgumentParser(description="Flower") - parser.add_argument("--node-id", type=int, required=True, choices=range(0, 10)) + parser.add_argument("--partition-id", type=int, required=True, choices=range(0, 10)) args = parser.parse_args() # Load data - trainloader, testloader = cifar.load_data(args.node_id) + trainloader, testloader = cifar.load_data(args.partition_id) # Load model model = cifar.Net().to(DEVICE).train() diff --git a/examples/pytorch-from-centralized-to-federated/run.sh b/examples/pytorch-from-centralized-to-federated/run.sh index 1ed51dd787ac..6ddf6ad476b4 100755 --- a/examples/pytorch-from-centralized-to-federated/run.sh +++ b/examples/pytorch-from-centralized-to-federated/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md index 5fdba887f181..ce7790cd4af5 100644 --- a/examples/quickstart-huggingface/README.md +++ b/examples/quickstart-huggingface/README.md @@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. 
diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 5dc461d30536..9be08d0cbcf4 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -17,10 +17,10 @@ CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(node_id): +def load_data(partition_id): """Load IMDB data (training and eval)""" fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) @@ -78,12 +78,12 @@ def test(net, testloader): return loss, accuracy -def main(node_id): +def main(partition_id): net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 ).to(DEVICE) - trainloader, testloader = load_data(node_id) + trainloader, testloader = load_data(partition_id) # Flower client class IMDBClient(fl.client.NumPyClient): @@ -116,12 +116,12 @@ def evaluate(self, parameters, config): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", choices=list(range(1_000)), required=True, type=int, help="Partition of the dataset divided into 1,000 iid partitions created " "artificially.", ) - node_id = parser.parse_args().node_id - main(node_id) + partition_id = parser.parse_args().partition_id + main(partition_id) diff --git a/examples/quickstart-huggingface/run.sh b/examples/quickstart-huggingface/run.sh index e722a24a21a9..fa989eab1471 100755 --- a/examples/quickstart-huggingface/run.sh +++ b/examples/quickstart-huggingface/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id ${i}& + python client.py --partition-id ${i}& done # This will allow you to use CTRL+C to stop all 
background processes diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md index d94a87a014f7..cca55bcb946a 100644 --- a/examples/quickstart-mlx/README.md +++ b/examples/quickstart-mlx/README.md @@ -66,19 +66,19 @@ following commands. Start a first client in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` And another one in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` If you want to utilize your GPU, you can use the `--gpu` argument: ```shell -python3 client.py --gpu --node-id 2 +python3 client.py --gpu --partition-id 2 ``` Note that you can start many more clients if you want, but each will have to be in its own terminal. @@ -96,7 +96,7 @@ We will use `flwr_datasets` to easily download and partition the `MNIST` dataset ```python fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) -partition = fds.load_partition(node_id = args.node_id) +partition = fds.load_partition(partition_id = args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits['train'].set_format("numpy") diff --git a/examples/quickstart-mlx/client.py b/examples/quickstart-mlx/client.py index 3b506399a5f1..faba2b94d6bd 100644 --- a/examples/quickstart-mlx/client.py +++ b/examples/quickstart-mlx/client.py @@ -89,7 +89,7 @@ def evaluate(self, parameters, config): parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", @@ -106,7 +106,7 @@ def evaluate(self, parameters, config): learning_rate = 1e-1 fds = FederatedDataset(dataset="mnist", partitioners={"train": 3}) - partition = fds.load_partition(node_id=args.node_id) + partition = 
fds.load_partition(partition_id=args.partition_id) partition_splits = partition.train_test_split(test_size=0.2) partition_splits["train"].set_format("numpy") diff --git a/examples/quickstart-mlx/run.sh b/examples/quickstart-mlx/run.sh index 70281049517d..40d211848c07 100755 --- a/examples/quickstart-mlx/run.sh +++ b/examples/quickstart-mlx/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-pandas/README.md b/examples/quickstart-pandas/README.md index efcda43cf34d..dd69f3ead3cb 100644 --- a/examples/quickstart-pandas/README.md +++ b/examples/quickstart-pandas/README.md @@ -70,13 +70,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -$ python3 client.py --node-id 0 +$ python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -$ python3 client.py --node-id 1 +$ python3 client.py --partition-id 1 ``` You will see that the server is printing aggregated statistics about the dataset distributed amongst clients. Have a look to the [Flower Quickstarter documentation](https://flower.ai/docs/quickstart-pandas.html) for a detailed explanation. 
diff --git a/examples/quickstart-pandas/client.py b/examples/quickstart-pandas/client.py index 8585922e4572..c52b7c65b04c 100644 --- a/examples/quickstart-pandas/client.py +++ b/examples/quickstart-pandas/client.py @@ -42,14 +42,14 @@ def fit( parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, - help="Specifies the node id of artificially partitioned datasets.", + help="Specifies the partition id of artificially partitioned datasets.", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) diff --git a/examples/quickstart-pandas/run.sh b/examples/quickstart-pandas/run.sh index 571fa8bfb3e4..2ae1e582b8cf 100755 --- a/examples/quickstart-pandas/run.sh +++ b/examples/quickstart-pandas/run.sh @@ -4,7 +4,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id ${i} & + python client.py --partition-id ${i} & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch-lightning/README.md b/examples/quickstart-pytorch-lightning/README.md index 1d404a5d714f..fb29c7e9e9ea 100644 --- a/examples/quickstart-pytorch-lightning/README.md +++ b/examples/quickstart-pytorch-lightning/README.md @@ -57,20 +57,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to use different partitions of the data on different nodes. 
To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python client.py --node-id 0 +python client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python client.py --node-id 1 +python client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index fc5f1ee03cfe..6e21259cc492 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -58,18 +58,18 @@ def _set_parameters(model, parameters): def main() -> None: parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, 10), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - node_id = args.node_id + partition_id = args.partition_id # Model and data model = mnist.LitAutoEncoder() - train_loader, val_loader, test_loader = mnist.load_data(node_id) + train_loader, val_loader, test_loader = mnist.load_data(partition_id) # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() diff --git a/examples/quickstart-pytorch-lightning/run.sh b/examples/quickstart-pytorch-lightning/run.sh index 60893a9a055b..62a1dac199bd 100755 --- a/examples/quickstart-pytorch-lightning/run.sh +++ b/examples/quickstart-pytorch-lightning/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-pytorch/README.md 
b/examples/quickstart-pytorch/README.md index 3b9b9b310608..02c9b4b38498 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -55,20 +55,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python3 server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the partition id to use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 +python3 client.py --partition-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 +python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index b5ea4c94dd21..e640ce111dff 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -69,10 +69,10 @@ def test(net, testloader): return loss, accuracy -def load_data(node_id): +def load_data(partition_id): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) - partition = fds.load_partition(node_id) + partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( @@ -94,20 +94,20 @@ def apply_transforms(batch): # 2. 
Federation of the pipeline with Flower # ############################################################################# -# Get node id +# Get partition id parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", choices=[0, 1, 2], required=True, type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", ) -node_id = parser.parse_args().node_id +partition_id = parser.parse_args().partition_id # Load model and data (simple CNN, CIFAR-10) net = Net().to(DEVICE) -trainloader, testloader = load_data(node_id=node_id) +trainloader, testloader = load_data(partition_id=partition_id) # Define Flower client diff --git a/examples/quickstart-pytorch/run.sh b/examples/quickstart-pytorch/run.sh index cdace99bb8df..6ca9c8cafec9 100755 --- a/examples/quickstart-pytorch/run.sh +++ b/examples/quickstart-pytorch/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "$i" & + python client.py --partition-id "$i" & done # Enable CTRL+C to stop all background processes diff --git a/examples/quickstart-sklearn-tabular/README.md b/examples/quickstart-sklearn-tabular/README.md index 373aaea5999c..a975a9392800 100644 --- a/examples/quickstart-sklearn-tabular/README.md +++ b/examples/quickstart-sklearn-tabular/README.md @@ -63,15 +63,15 @@ poetry run python3 server.py Now you are ready to start the Flower clients which will participate in the learning. 
To do so simply open two more terminals and run the following command in each: ```shell -poetry run python3 client.py --node-id 0 # node-id should be any of {0,1,2} +poetry run python3 client.py --partition-id 0 # partition-id should be any of {0,1,2} ``` Alternatively you can run all of it in one shell as follows: ```shell poetry run python3 server.py & -poetry run python3 client.py --node-id 0 & -poetry run python3 client.py --node-id 1 +poetry run python3 client.py --partition-id 0 & +poetry run python3 client.py --partition-id 1 ``` You will see that Flower is starting a federated training. diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py index 5dc0e88b3c75..fcab8f5d5612 100644 --- a/examples/quickstart-sklearn-tabular/client.py +++ b/examples/quickstart-sklearn-tabular/client.py @@ -13,14 +13,14 @@ parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) diff --git a/examples/quickstart-sklearn-tabular/run.sh b/examples/quickstart-sklearn-tabular/run.sh index 48cee1b41b74..f770ca05f8f4 100755 --- a/examples/quickstart-sklearn-tabular/run.sh +++ b/examples/quickstart-sklearn-tabular/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/quickstart-tensorflow/client.py b/examples/quickstart-tensorflow/client.py index 37abbbcc46ec..3e2035c09311 100644 --- a/examples/quickstart-tensorflow/client.py 
+++ b/examples/quickstart-tensorflow/client.py @@ -11,7 +11,7 @@ # Parse arguments parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=[0, 1, 2], required=True, @@ -26,7 +26,7 @@ # Download and partition dataset fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) -partition = fds.load_partition(args.node_id, "train") +partition = fds.load_partition(args.partition_id, "train") partition.set_format("numpy") # Divide data on each node: 80% train, 20% test diff --git a/examples/quickstart-tensorflow/run.sh b/examples/quickstart-tensorflow/run.sh index 439abea8df4b..76188f197e3e 100755 --- a/examples/quickstart-tensorflow/run.sh +++ b/examples/quickstart-tensorflow/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py --node-id $i & + python client.py --partition-id $i & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/sklearn-logreg-mnist/README.md b/examples/sklearn-logreg-mnist/README.md index 50576d98ba3d..12b1a5e3bc1a 100644 --- a/examples/sklearn-logreg-mnist/README.md +++ b/examples/sklearn-logreg-mnist/README.md @@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py --node-id 0 # or any integer in {0-9} +python3 client.py --partition-id 0 # or any integer in {0-9} ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id 1 # or any integer in {0-9} +python3 client.py --partition-id 1 # or any integer in {0-9} ``` Alternatively, you can run all of it in one shell as follows: diff --git a/examples/sklearn-logreg-mnist/client.py b/examples/sklearn-logreg-mnist/client.py index 3d41cb6fbb21..1e9349df1acc 100644 --- a/examples/sklearn-logreg-mnist/client.py +++ 
b/examples/sklearn-logreg-mnist/client.py @@ -13,14 +13,14 @@ parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--node-id", + "--partition-id", type=int, choices=range(0, N_CLIENTS), required=True, help="Specifies the artificial data partition", ) args = parser.parse_args() - partition_id = args.node_id + partition_id = args.partition_id # Load the partition data fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS}) diff --git a/examples/sklearn-logreg-mnist/run.sh b/examples/sklearn-logreg-mnist/run.sh index 48cee1b41b74..f770ca05f8f4 100755 --- a/examples/sklearn-logreg-mnist/run.sh +++ b/examples/sklearn-logreg-mnist/run.sh @@ -8,7 +8,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in $(seq 0 1); do echo "Starting client $i" - python client.py --node-id "${i}" & + python client.py --partition-id "${i}" & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index 01fed646d056..dc6d7e3872d6 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -120,10 +120,10 @@ You can also run the example without the scripts. 
First, launch the server: python server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N ``` -Then run at least two clients (each on a new terminal or computer in your network) passing different `NODE_ID` and all using the same `N` (denoting the total number of clients or data partitions): +Then run at least two clients (each on a new terminal or computer in your network) passing different `PARTITION_ID` and all using the same `N` (denoting the total number of clients or data partitions): ```bash -python client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N +python client.py --train-method=bagging/cyclic --partition-id=PARTITION_ID --num-partitions=N ``` ### Flower Simulation Setup diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py index 74fbc4f5366a..66daed449fd5 100644 --- a/examples/xgboost-comprehensive/client.py +++ b/examples/xgboost-comprehensive/client.py @@ -35,9 +35,9 @@ resplitter=resplit, ) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -partition = fds.load_partition(node_id=args.node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") if args.centralised_eval: diff --git a/examples/xgboost-comprehensive/run_bagging.sh b/examples/xgboost-comprehensive/run_bagging.sh index e853a4ef19cb..a6300b781a06 100755 --- a/examples/xgboost-comprehensive/run_bagging.sh +++ b/examples/xgboost-comprehensive/run_bagging.sh @@ -8,7 +8,7 @@ sleep 30 # Sleep for 30s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --num-partitions=5 --partitioner-type=exponential & + python3 client.py --partition-id=$i --num-partitions=5 --partitioner-type=exponential & done # Enable CTRL+C to stop all background processes diff --git 
a/examples/xgboost-comprehensive/run_cyclic.sh b/examples/xgboost-comprehensive/run_cyclic.sh index 47e09fd8faef..258bdf2fe0d8 100755 --- a/examples/xgboost-comprehensive/run_cyclic.sh +++ b/examples/xgboost-comprehensive/run_cyclic.sh @@ -8,7 +8,7 @@ sleep 15 # Sleep for 15s to give the server enough time to start for i in `seq 0 4`; do echo "Starting client $i" - python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & + python3 client.py --partition-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval & done # Enable CTRL+C to stop all background processes diff --git a/examples/xgboost-comprehensive/sim.py b/examples/xgboost-comprehensive/sim.py index ae2a9239a493..b72b23931929 100644 --- a/examples/xgboost-comprehensive/sim.py +++ b/examples/xgboost-comprehensive/sim.py @@ -98,9 +98,9 @@ def main(): # Load and process all client partitions. This upfront cost is amortized soon # after the simulation begins since clients wont need to preprocess their partition. 
- for node_id in tqdm(range(args.pool_size), desc="Extracting client partition"): - # Extract partition for client with node_id - partition = fds.load_partition(node_id=node_id, split="train") + for partition_id in tqdm(range(args.pool_size), desc="Extracting client partition"): + # Extract partition for client with partition_id + partition = fds.load_partition(partition_id=partition_id, split="train") partition.set_format("numpy") if args.centralised_eval_client: diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py index 102587f4266d..abc100da1ade 100644 --- a/examples/xgboost-comprehensive/utils.py +++ b/examples/xgboost-comprehensive/utils.py @@ -37,10 +37,10 @@ def client_args_parser(): help="Partitioner types.", ) parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) parser.add_argument( "--seed", default=42, type=int, help="Seed used for train/test splitting." diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index cd99cd4c2895..72dde5706e8d 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -67,13 +67,13 @@ To do so simply open two more terminal windows and run the following commands. Start client 1 in the first terminal: ```shell -python3 client.py --node-id=0 +python3 client.py --partition-id=0 ``` Start client 2 in the second terminal: ```shell -python3 client.py --node-id=1 +python3 client.py --partition-id=1 ``` You will see that XGBoost is starting a federated training. 
diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py index 62e8a441bae1..6ac23ae15148 100644 --- a/examples/xgboost-quickstart/client.py +++ b/examples/xgboost-quickstart/client.py @@ -24,13 +24,13 @@ warnings.filterwarnings("ignore", category=UserWarning) -# Define arguments parser for the client/node ID. +# Define arguments parser for the client/partition ID. parser = argparse.ArgumentParser() parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) args = parser.parse_args() @@ -61,9 +61,9 @@ def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core. partitioner = IidPartitioner(num_partitions=30) fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) -# Load the partition for this `node_id` +# Load the partition for this `partition_id` log(INFO, "Loading partition...") -partition = fds.load_partition(node_id=args.node_id, split="train") +partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") # Train/test splitting diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh index 6287145bfb5f..b35af58222ab 100755 --- a/examples/xgboost-quickstart/run.sh +++ b/examples/xgboost-quickstart/run.sh @@ -8,7 +8,7 @@ sleep 5 # Sleep for 5s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python3 client.py --node-id=$i & + python3 client.py --partition-id=$i & done # Enable CTRL+C to stop all background processes From 256548ea17fc443436ed30cd34f92554ec009999 Mon Sep 17 00:00:00 2001 From: mohammadnaseri Date: Wed, 28 Feb 2024 11:20:11 +0000 Subject: [PATCH 16/17] Update cryptography version (#3028) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6bd5c74f29a9..b52be36b7ff5 
100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ python = "^3.8" numpy = "^1.21.0" grpcio = "^1.60.0" protobuf = "^4.25.2" -cryptography = "^41.0.2" +cryptography = "^42.0.4" pycryptodome = "^3.18.0" iterators = "^0.0.2" typer = { version = "^0.9.0", extras=["all"] } From d663cd46cdca6cc67664ac3bb9fbe6465e93a062 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 28 Feb 2024 12:12:12 +0000 Subject: [PATCH 17/17] Add `Workflow` (#3017) Co-authored-by: Daniel J. Beutel --- examples/app-pytorch/README.md | 16 +- examples/app-pytorch/server_workflow.py | 55 +++ src/py/flwr/server/__init__.py | 4 + src/py/flwr/server/compat/__init__.py | 2 + src/py/flwr/server/compat/app.py | 95 +---- src/py/flwr/server/compat/app_test.py | 84 ----- src/py/flwr/server/compat/app_utils.py | 102 +++++ src/py/flwr/server/compat/app_utils_test.py | 62 +++ src/py/flwr/server/compat/legacy_context.py | 55 +++ src/py/flwr/server/run_serverapp.py | 2 +- src/py/flwr/server/typing.py | 1 + src/py/flwr/server/workflow/__init__.py | 22 ++ .../flwr/server/workflow/default_workflows.py | 357 ++++++++++++++++++ 13 files changed, 680 insertions(+), 177 deletions(-) create mode 100644 examples/app-pytorch/server_workflow.py delete mode 100644 src/py/flwr/server/compat/app_test.py create mode 100644 src/py/flwr/server/compat/app_utils.py create mode 100644 src/py/flwr/server/compat/app_utils_test.py create mode 100644 src/py/flwr/server/compat/legacy_context.py create mode 100644 src/py/flwr/server/workflow/__init__.py create mode 100644 src/py/flwr/server/workflow/default_workflows.py diff --git a/examples/app-pytorch/README.md b/examples/app-pytorch/README.md index bcc7aa045e26..de1b6fdbb819 100644 --- a/examples/app-pytorch/README.md +++ b/examples/app-pytorch/README.md @@ -12,10 +12,12 @@ Let's assume the following project structure: ```bash $ tree . . 
-├── client.py # <-- contains `ClientApp` -├── server.py # <-- contains `ServerApp` -├── task.py # <-- task-specific code (model, data) -└── requirements.txt # <-- dependencies +├── client.py # <-- contains `ClientApp` +├── server.py # <-- contains `ServerApp` +├── server_workflow.py # <-- contains `ServerApp` with workflow +├── server_custom.py # <-- contains `ServerApp` with custom main function +├── task.py # <-- task-specific code (model, data) +└── requirements.txt # <-- dependencies ``` ## Install dependencies @@ -52,6 +54,12 @@ With both the long-running server (SuperLink) and two clients (SuperNode) up and flower-server-app server:app --insecure ``` +Or, to try the workflow example, run: + +```bash +flower-server-app server_workflow:app --insecure +``` + Or, to try the custom server function example, run: ```bash diff --git a/examples/app-pytorch/server_workflow.py b/examples/app-pytorch/server_workflow.py new file mode 100644 index 000000000000..920e266c99e9 --- /dev/null +++ b/examples/app-pytorch/server_workflow.py @@ -0,0 +1,55 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Context, Metrics +from flwr.server import Driver, LegacyContext + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / 
sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + +# Define strategy +strategy = fl.server.strategy.FedAvg( + fraction_fit=1.0, # Select all available clients + fraction_evaluate=0.0, # Disable evaluation + min_available_clients=2, + fit_metrics_aggregation_fn=weighted_average, +) + + +# Run via `flower-server-app server_workflow:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + # Construct the LegacyContext + context = LegacyContext( + state=context.state, + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, + ) + + # Create the workflow + workflow = fl.server.workflow.DefaultWorkflow() + + # Execute + workflow(driver, context) diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index 969bea96d1fe..633bd668b520 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -16,12 +16,14 @@ from . import strategy +from . import workflow as workflow from .app import run_driver_api as run_driver_api from .app import run_fleet_api as run_fleet_api from .app import run_superlink as run_superlink from .app import start_server as start_server from .client_manager import ClientManager as ClientManager from .client_manager import SimpleClientManager as SimpleClientManager +from .compat import LegacyContext as LegacyContext from .compat import start_driver as start_driver from .driver import Driver as Driver from .history import History as History @@ -34,6 +36,7 @@ "ClientManager", "Driver", "History", + "LegacyContext", "run_driver_api", "run_fleet_api", "run_server_app", @@ -45,4 +48,5 @@ "start_driver", "start_server", "strategy", + "workflow", ] diff --git a/src/py/flwr/server/compat/__init__.py b/src/py/flwr/server/compat/__init__.py index 3a0c2b4e83a0..7bae196ddb65 100644 --- a/src/py/flwr/server/compat/__init__.py +++ b/src/py/flwr/server/compat/__init__.py @@ -16,7 +16,9 @@ from .app import start_driver as start_driver 
+from .legacy_context import LegacyContext as LegacyContext __all__ = [ + "LegacyContext", "start_driver", ] diff --git a/src/py/flwr/server/compat/app.py b/src/py/flwr/server/compat/app.py index 2e441db08668..203317a3e348 100644 --- a/src/py/flwr/server/compat/app.py +++ b/src/py/flwr/server/compat/app.py @@ -16,16 +16,13 @@ import sys -import threading -import time from logging import INFO from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Union from flwr.common import EventType, event from flwr.common.address import parse_address from flwr.common.logger import log, warn_deprecated_feature -from flwr.proto import driver_pb2 # pylint: disable=E0611 from flwr.server.client_manager import ClientManager from flwr.server.history import History from flwr.server.server import Server, init_defaults, run_fl @@ -33,8 +30,7 @@ from flwr.server.strategy import Strategy from ..driver import Driver -from ..driver.grpc_driver import GrpcDriver -from .driver_client_proxy import DriverClientProxy +from .app_utils import start_update_client_manager_thread DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -104,11 +100,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals """ event(EventType.START_DRIVER_ENTER) - if driver: - # pylint: disable=protected-access - grpc_driver, _ = driver._get_grpc_driver_and_run_id() - # pylint: enable=protected-access - else: + if driver is None: # Not passing a `Driver` object is deprecated warn_deprecated_feature("start_driver") @@ -122,12 +114,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals # Create the Driver if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() - grpc_driver = GrpcDriver( + driver = Driver( driver_service_address=address, root_certificates=root_certificates ) - grpc_driver.connect() - - lock = threading.Lock() # Initialize the Driver API server and config initialized_server, initialized_config 
= init_defaults( @@ -142,18 +131,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals initialized_config, ) - f_stop = threading.Event() # Start the thread updating nodes - thread = threading.Thread( - target=update_client_manager, - args=( - grpc_driver, - initialized_server.client_manager(), - lock, - f_stop, - ), + thread, f_stop = start_update_client_manager_thread( + driver, initialized_server.client_manager() ) - thread.start() # Start training hist = run_fl( @@ -164,72 +145,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals f_stop.set() # Stop the Driver API server and the thread - with lock: - if driver: - del driver - else: - grpc_driver.disconnect() + del driver thread.join() event(EventType.START_SERVER_LEAVE) return hist - - -def update_client_manager( - driver: GrpcDriver, - client_manager: ClientManager, - lock: threading.Lock, - f_stop: threading.Event, -) -> None: - """Update the nodes list in the client manager. - - This function periodically communicates with the associated driver to get all - node_ids. Each node_id is then converted into a `DriverClientProxy` instance - and stored in the `registered_nodes` dictionary with node_id as key. - - New nodes will be added to the ClientManager via `client_manager.register()`, - and dead nodes will be removed from the ClientManager via - `client_manager.unregister()`. 
- """ - # Request for run_id - run_id = driver.create_run( - driver_pb2.CreateRunRequest() # pylint: disable=E1101 - ).run_id - - # Loop until the driver is disconnected - registered_nodes: Dict[int, DriverClientProxy] = {} - while not f_stop.is_set(): - with lock: - # End the while loop if the driver is disconnected - if driver.stub is None: - break - get_nodes_res = driver.get_nodes( - req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101 - ) - all_node_ids = {node.node_id for node in get_nodes_res.nodes} - dead_nodes = set(registered_nodes).difference(all_node_ids) - new_nodes = all_node_ids.difference(registered_nodes) - - # Unregister dead nodes - for node_id in dead_nodes: - client_proxy = registered_nodes[node_id] - client_manager.unregister(client_proxy) - del registered_nodes[node_id] - - # Register new nodes - for node_id in new_nodes: - client_proxy = DriverClientProxy( - node_id=node_id, - driver=driver, - anonymous=False, - run_id=run_id, - ) - if client_manager.register(client_proxy): - registered_nodes[node_id] = client_proxy - else: - raise RuntimeError("Could not register node.") - - # Sleep for 3 seconds - time.sleep(3) diff --git a/src/py/flwr/server/compat/app_test.py b/src/py/flwr/server/compat/app_test.py deleted file mode 100644 index 77eeda2848c1..000000000000 --- a/src/py/flwr/server/compat/app_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower Driver app tests.""" - - -import threading -import time -import unittest -from unittest.mock import MagicMock - -from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunResponse, - GetNodesResponse, -) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.server.client_manager import SimpleClientManager - -from .app import update_client_manager - - -class TestClientManagerWithDriver(unittest.TestCase): - """Tests for ClientManager. - - Considering multi-threading, all tests assume that the `update_client_manager()` - updates the ClientManager every 3 seconds. - """ - - def test_simple_client_manager_update(self) -> None: - """Tests if the node update works correctly.""" - # Prepare - expected_nodes = [Node(node_id=i, anonymous=False) for i in range(100)] - expected_updated_nodes = [ - Node(node_id=i, anonymous=False) for i in range(80, 120) - ] - driver = MagicMock() - driver.stub = "driver stub" - driver.create_run.return_value = CreateRunResponse(run_id=1) - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) - client_manager = SimpleClientManager() - lock = threading.Lock() - f_stop = threading.Event() - # Execute - thread = threading.Thread( - target=update_client_manager, - args=(driver, client_manager, lock, f_stop), - daemon=True, - ) - thread.start() - # Wait until all nodes are registered via `client_manager.sample()` - client_manager.sample(len(expected_nodes)) - # Retrieve all nodes in `client_manager` - node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Update the GetNodesResponse and wait until the `client_manager` is updated - driver.get_nodes.return_value = GetNodesResponse(nodes=expected_updated_nodes) - while True: - with lock: - if len(client_manager.all()) == 
len(expected_updated_nodes): - break - time.sleep(1.3) - # Retrieve all nodes in `client_manager` - updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Simulate `driver.disconnect()` - driver.stub = None - - # Assert - driver.create_run.assert_called_once() - assert node_ids == {node.node_id for node in expected_nodes} - assert updated_node_ids == {node.node_id for node in expected_updated_nodes} - - f_stop.set() - # Exit - thread.join() diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py new file mode 100644 index 000000000000..696ec1132c4a --- /dev/null +++ b/src/py/flwr/server/compat/app_utils.py @@ -0,0 +1,102 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for the `start_driver`.""" + + +import threading +import time +from typing import Dict, Tuple + +from ..client_manager import ClientManager +from ..compat.driver_client_proxy import DriverClientProxy +from ..driver import Driver + + +def start_update_client_manager_thread( + driver: Driver, + client_manager: ClientManager, +) -> Tuple[threading.Thread, threading.Event]: + """Periodically update the nodes list in the client manager in a thread. + + This function starts a thread that periodically uses the associated driver to + get all node_ids. 
Each node_id is then converted into a `DriverClientProxy` + instance and stored in the `registered_nodes` dictionary with node_id as key. + + New nodes will be added to the ClientManager via `client_manager.register()`, + and dead nodes will be removed from the ClientManager via + `client_manager.unregister()`. + + Parameters + ---------- + driver : Driver + The Driver object to use. + client_manager : ClientManager + The ClientManager object to be updated. + + Returns + ------- + threading.Thread + A thread that updates the ClientManager and handles the stop event. + threading.Event + An event that, when set, signals the thread to stop. + """ + f_stop = threading.Event() + thread = threading.Thread( + target=_update_client_manager, + args=( + driver, + client_manager, + f_stop, + ), + ) + thread.start() + + return thread, f_stop + + +def _update_client_manager( + driver: Driver, + client_manager: ClientManager, + f_stop: threading.Event, +) -> None: + """Update the nodes list in the client manager.""" + # Loop until the driver is disconnected + registered_nodes: Dict[int, DriverClientProxy] = {} + while not f_stop.is_set(): + all_node_ids = set(driver.get_node_ids()) + dead_nodes = set(registered_nodes).difference(all_node_ids) + new_nodes = all_node_ids.difference(registered_nodes) + + # Unregister dead nodes + for node_id in dead_nodes: + client_proxy = registered_nodes[node_id] + client_manager.unregister(client_proxy) + del registered_nodes[node_id] + + # Register new nodes + for node_id in new_nodes: + client_proxy = DriverClientProxy( + node_id=node_id, + driver=driver.grpc_driver, # type: ignore + anonymous=False, + run_id=driver.run_id, # type: ignore + ) + if client_manager.register(client_proxy): + registered_nodes[node_id] = client_proxy + else: + raise RuntimeError("Could not register node.") + + # Sleep for 3 seconds + time.sleep(3) diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py new file mode 
100644 index 000000000000..7e47e6eaaf32 --- /dev/null +++ b/src/py/flwr/server/compat/app_utils_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utility functions for the `start_driver`.""" + + +import time +import unittest +from unittest.mock import Mock, patch + +from ..client_manager import SimpleClientManager +from .app_utils import start_update_client_manager_thread + + +class TestUtils(unittest.TestCase): + """Tests for utility functions.""" + + def test_start_update_client_manager_thread(self) -> None: + """Test start_update_client_manager_thread function.""" + # Prepare + sleep = time.sleep + sleep_patch = patch("time.sleep", lambda x: sleep(x / 100)) + sleep_patch.start() + expected_node_ids = list(range(100)) + updated_expected_node_ids = list(range(80, 120)) + driver = Mock() + driver.grpc_driver = Mock() + driver.run_id = 123 + driver.get_node_ids.return_value = expected_node_ids + client_manager = SimpleClientManager() + + # Execute + thread, f_stop = start_update_client_manager_thread(driver, client_manager) + # Wait until all nodes are registered via `client_manager.sample()` + client_manager.sample(len(expected_node_ids)) + # Retrieve all nodes in `client_manager` + node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Update the GetNodesResponse and wait 
until the `client_manager` is updated + driver.get_node_ids.return_value = updated_expected_node_ids + sleep(0.1) + # Retrieve all nodes in `client_manager` + updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Stop the thread + f_stop.set() + + # Assert + assert node_ids == set(expected_node_ids) + assert updated_node_ids == set(updated_expected_node_ids) + + # Exit + thread.join() diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py new file mode 100644 index 000000000000..0b00c98bb16d --- /dev/null +++ b/src/py/flwr/server/compat/legacy_context.py @@ -0,0 +1,55 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Legacy Context.""" + + +from dataclasses import dataclass +from typing import Optional + +from flwr.common import Context, RecordSet + +from ..client_manager import ClientManager, SimpleClientManager +from ..history import History +from ..server_config import ServerConfig +from ..strategy import FedAvg, Strategy + + +@dataclass +class LegacyContext(Context): + """Legacy Context.""" + + config: ServerConfig + strategy: Strategy + client_manager: ClientManager + history: History + + def __init__( + self, + state: RecordSet, + config: Optional[ServerConfig] = None, + strategy: Optional[Strategy] = None, + client_manager: Optional[ClientManager] = None, + ) -> None: + if config is None: + config = ServerConfig() + if strategy is None: + strategy = FedAvg() + if client_manager is None: + client_manager = SimpleClientManager() + self.config = config + self.strategy = strategy + self.client_manager = client_manager + self.history = History() + super().__init__(state) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index e7205ebd1444..19fd16fb0c1a 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -107,7 +107,7 @@ def run_server_app() -> None: run(server_app_attr, driver, server_app_dir) # Clean up - del driver + driver.__del__() # pylint: disable=unnecessary-dunder-call event(EventType.RUN_SERVER_APP_LEAVE) diff --git a/src/py/flwr/server/typing.py b/src/py/flwr/server/typing.py index fa84322bc785..01143af74392 100644 --- a/src/py/flwr/server/typing.py +++ b/src/py/flwr/server/typing.py @@ -22,3 +22,4 @@ from .driver import Driver ServerAppCallable = Callable[[Driver, Context], None] +Workflow = Callable[[Driver, Context], None] diff --git a/src/py/flwr/server/workflow/__init__.py b/src/py/flwr/server/workflow/__init__.py new file mode 100644 index 000000000000..098b0dbfb92f --- /dev/null +++ 
b/src/py/flwr/server/workflow/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Workflows.""" + + +from .default_workflows import DefaultWorkflow + +__all__ = [ + "DefaultWorkflow", +] diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py new file mode 100644 index 000000000000..5c6c1e2d114e --- /dev/null +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -0,0 +1,357 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Legacy default workflows.""" + + +import timeit +from logging import DEBUG, INFO +from typing import Optional, cast + +import flwr.common.recordset_compat as compat +from flwr.common import ConfigsRecord, Context, GetParametersIns, log +from flwr.common.constant import ( + MESSAGE_TYPE_EVALUATE, + MESSAGE_TYPE_FIT, + MESSAGE_TYPE_GET_PARAMETERS, +) + +from ..compat.app_utils import start_update_client_manager_thread +from ..compat.legacy_context import LegacyContext +from ..driver import Driver +from ..typing import Workflow + +KEY_CURRENT_ROUND = "current_round" +KEY_START_TIME = "start_time" +CONFIGS_RECORD_KEY = "config" +PARAMS_RECORD_KEY = "parameters" + + +class DefaultWorkflow: + """Default workflow in Flower.""" + + def __init__( + self, + fit_workflow: Optional[Workflow] = None, + evaluate_workflow: Optional[Workflow] = None, + ) -> None: + if fit_workflow is None: + fit_workflow = default_fit_workflow + if evaluate_workflow is None: + evaluate_workflow = default_evaluate_workflow + self.fit_workflow: Workflow = fit_workflow + self.evaluate_workflow: Workflow = evaluate_workflow + + def __call__(self, driver: Driver, context: Context) -> None: + """Execute the workflow.""" + if not isinstance(context, LegacyContext): + raise TypeError( + f"Expect a LegacyContext, but get {type(context).__name__}." 
+ ) + + # Start the thread updating nodes + thread, f_stop = start_update_client_manager_thread( + driver, context.client_manager + ) + + # Initialize parameters + default_init_params_workflow(driver, context) + + # Run federated learning for num_rounds + log(INFO, "FL starting") + start_time = timeit.default_timer() + cfg = ConfigsRecord() + cfg[KEY_START_TIME] = start_time + context.state.configs_records[CONFIGS_RECORD_KEY] = cfg + + for current_round in range(1, context.config.num_rounds + 1): + cfg[KEY_CURRENT_ROUND] = current_round + + # Fit round + self.fit_workflow(driver, context) + + # Centralized evaluation + default_centralized_evaluation_workflow(driver, context) + + # Evaluate round + self.evaluate_workflow(driver, context) + + # Bookkeeping + end_time = timeit.default_timer() + elapsed = end_time - start_time + log(INFO, "FL finished in %s", elapsed) + + # Log results + hist = context.history + log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed)) + log( + INFO, + "app_fit: metrics_distributed_fit %s", + str(hist.metrics_distributed_fit), + ) + log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed)) + log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized)) + log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized)) + + # Terminate the thread + f_stop.set() + del driver + thread.join() + + +def default_init_params_workflow(driver: Driver, context: Context) -> None: + """Execute the default workflow for parameters initialization.""" + if not isinstance(context, LegacyContext): + raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") + + log(INFO, "Initializing global parameters") + parameters = context.strategy.initialize_parameters( + client_manager=context.client_manager + ) + if parameters is not None: + log(INFO, "Using initial parameters provided by strategy") + paramsrecord = compat.parameters_to_parametersrecord( + parameters, keep_input=True + ) + 
else: + # Get initial parameters from one of the clients + log(INFO, "Requesting initial parameters from one random client") + random_client = context.client_manager.sample(1)[0] + # Send GetParametersIns and get the response + content = compat.getparametersins_to_recordset(GetParametersIns({})) + messages = driver.send_and_receive( + [ + driver.create_message( + content=content, + message_type=MESSAGE_TYPE_GET_PARAMETERS, + dst_node_id=random_client.node_id, + group_id="", + ttl="", + ) + ] + ) + log(INFO, "Received initial parameters from one random client") + msg = list(messages)[0] + paramsrecord = next(iter(msg.content.parameters_records.values())) + + context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord + + # Evaluate initial parameters + log(INFO, "Evaluating initial parameters") + parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True) + res = context.strategy.evaluate(0, parameters=parameters) + if res is not None: + log( + INFO, + "initial parameters (loss, other metrics): %s, %s", + res[0], + res[1], + ) + context.history.add_loss_centralized(server_round=0, loss=res[0]) + context.history.add_metrics_centralized(server_round=0, metrics=res[1]) + + +def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None: + """Execute the default workflow for centralized evaluation.""" + if not isinstance(context, LegacyContext): + raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") + + # Retrieve current_round and start_time from the context + cfg = context.state.configs_records[CONFIGS_RECORD_KEY] + current_round = cast(int, cfg[KEY_CURRENT_ROUND]) + start_time = cast(float, cfg[KEY_START_TIME]) + + # Centralized evaluation + parameters = compat.parametersrecord_to_parameters( + record=context.state.parameters_records[PARAMS_RECORD_KEY], + keep_input=True, + ) + res_cen = context.strategy.evaluate(current_round, parameters=parameters) + if res_cen is not None: + loss_cen, 
def default_fit_workflow(driver: Driver, context: Context) -> None:
    """Execute the default workflow for a single fit round.

    Configures clients via the strategy, ships FitIns messages through the
    driver, collects the replies, and aggregates them back into the
    context's parameter record and distributed-fit metrics history.
    """
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Pull the round number and current global parameters from state
    cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
    current_round = cast(int, cfg[KEY_CURRENT_ROUND])
    parameters = compat.parametersrecord_to_parameters(
        context.state.parameters_records[PARAMS_RECORD_KEY],
        keep_input=True,
    )

    # Let the strategy decide which clients train this round, and with what
    client_instructions = context.strategy.configure_fit(
        server_round=current_round,
        parameters=parameters,
        client_manager=context.client_manager,
    )
    if not client_instructions:
        log(INFO, "fit_round %s: no clients selected, cancel", current_round)
        return
    log(
        DEBUG,
        "fit_round %s: strategy sampled %s clients (out of %s)",
        current_round,
        len(client_instructions),
        context.client_manager.num_available(),
    )

    # Build one outgoing message per sampled client, remembering which
    # proxy belongs to which node so replies can be matched afterwards
    node_id_to_proxy = {}
    out_messages = []
    for proxy, fitins in client_instructions:
        node_id_to_proxy[proxy.node_id] = proxy
        out_messages.append(
            driver.create_message(
                content=compat.fitins_to_recordset(fitins, True),
                message_type=MESSAGE_TYPE_FIT,
                dst_node_id=proxy.node_id,
                group_id="",
                ttl="",
            )
        )

    # Send instructions to clients and
    # collect `fit` results from all clients participating in this round
    messages = list(driver.send_and_receive(out_messages))
    del out_messages  # release message payloads early

    # No exception/failure handling currently
    log(
        DEBUG,
        "fit_round %s received %s results and %s failures",
        current_round,
        len(messages),
        0,
    )

    # Convert replies back into (ClientProxy, FitRes) pairs and aggregate
    results = [
        (
            node_id_to_proxy[reply.metadata.src_node_id],
            compat.recordset_to_fitres(reply.content, False),
        )
        for reply in messages
    ]
    parameters_aggregated, metrics_aggregated = context.strategy.aggregate_fit(
        current_round, results, []
    )

    # Persist the new global parameters and record the fit metrics
    if parameters_aggregated:
        context.state.parameters_records[PARAMS_RECORD_KEY] = (
            compat.parameters_to_parametersrecord(parameters_aggregated, True)
        )
        context.history.add_metrics_distributed_fit(
            server_round=current_round, metrics=metrics_aggregated
        )
{proxy.node_id: proxy for proxy, _ in client_instructions} + + # Build out messages + out_messages = [ + driver.create_message( + content=compat.evaluateins_to_recordset(evalins, True), + message_type=MESSAGE_TYPE_EVALUATE, + dst_node_id=proxy.node_id, + group_id="", + ttl="", + ) + for proxy, evalins in client_instructions + ] + + # Send instructions to clients and + # collect `evaluate` results from all clients participating in this round + messages = list(driver.send_and_receive(out_messages)) + del out_messages + + # No exception/failure handling currently + log( + DEBUG, + "evaluate_round %s received %s results and %s failures", + current_round, + len(messages), + 0, + ) + + # Aggregate the evaluation results + results = [ + ( + node_id_to_proxy[msg.metadata.src_node_id], + compat.recordset_to_evaluateres(msg.content), + ) + for msg in messages + ] + aggregated_result = context.strategy.aggregate_evaluate(current_round, results, []) + + loss_aggregated, metrics_aggregated = aggregated_result + + # Write history + if loss_aggregated is not None: + context.history.add_loss_distributed( + server_round=current_round, loss=loss_aggregated + ) + context.history.add_metrics_distributed( + server_round=current_round, metrics=metrics_aggregated + )