diff --git a/.github/workflows/framework.yml b/.github/workflows/framework.yml index 784f04750c5e..feb08229be06 100644 --- a/.github/workflows/framework.yml +++ b/.github/workflows/framework.yml @@ -31,6 +31,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Bootstrap uses: ./.github/actions/bootstrap with: diff --git a/README.md b/README.md index a010abfcb2f5..16cb7f1cfaf6 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples): - [Flower with KaplanMeierFitter from the lifelines library](https://github.com/adap/flower/tree/main/examples/federated-kaplan-meier-fitter) - [Sample Level Privacy with Opacus](https://github.com/adap/flower/tree/main/examples/opacus) - [Sample Level Privacy with TensorFlow-Privacy](https://github.com/adap/flower/tree/main/examples/tensorflow-privacy) +- [Flower with a Tabular Dataset] (https://github.com/adap/flower/tree/main/examples/fl-tabular) ## Community diff --git a/benchmarks/flowertune-llm/README.md b/benchmarks/flowertune-llm/README.md new file mode 100644 index 000000000000..0cb69e7ff9c7 --- /dev/null +++ b/benchmarks/flowertune-llm/README.md @@ -0,0 +1,61 @@ +![](_static/flower_llm.jpg) + +# FlowerTune LLM Leaderboard + +This repository guides you through the process of federated LLM instruction tuning with a +pre-trained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model across 4 domains --- general NLP, finance, medical and code. + +Please follow the instructions to run and evaluate the federated LLMs. + +## Create a new project + +As the first step, please register a Flower account on [Flower website](https://flower.ai/login). +Assuming `flwr` package is already installed on your system (check [here](https://flower.ai/docs/framework/how-to-install-flower.html) for `flwr` installation). +We provide a single-line command to create a new project directory based on your selected challenge: + +```shell +flwr new --framework=flwrtune --username=your_flower_account +``` + +Then you will see a prompt to ask your project name and the choice of LLM challenges from the set of general NLP, finance, medical and code. +Type your project name and select your preferred challenge, +and then a new project directory will be generated automatically. + +### Structure + +After running `flwr new`, you will see a new directory generated with the following structure: + +```bash + + ├── README.md # <- Instructions + ├── pyproject.toml # <- Environment dependencies + └── + ├── app.py # <- Flower ClientApp/ServerApp build + ├── client.py # <- Flower client constructor + ├── server.py # <- Sever-related functions + ├── models.py # <- Model build + ├── dataset.py # <- Dataset and tokenizer build + ├── conf/config.yaml # <- User configuration + └── conf/static_config.yaml # <- Static configuration +``` + +This can serve as the starting point for you to build up your own federated LLM fine-tuning methods. +Please note that any modification to the content of `conf/static_config.yaml` is strictly prohibited for those who wish to participate in the [LLM Leaderboard](https://flower.ai/benchmarks/llm-leaderboard). +Otherwise, the submission will not be considered. + +## Run FlowerTune LLM challenges + +With a new project directory created, running a baseline challenge can be done by: + +1. Navigate inside the directory that you just created. + + +2. Follow the `Environments setup` section of `README.md` in the project directory to install project dependencies. + + +3. Run the challenge as indicated in the `Running the challenge` section in the `README.md`. + +## Evaluate pre-trained LLMs + +After the LLM fine-tuning finished, evaluate the performance of your pre-trained LLMs +following the `README.md` in `evaluation` directory. diff --git a/benchmarks/flowertune-llm/_static/flower_llm.jpg b/benchmarks/flowertune-llm/_static/flower_llm.jpg new file mode 100644 index 000000000000..96081d9c2ad1 Binary files /dev/null and b/benchmarks/flowertune-llm/_static/flower_llm.jpg differ diff --git a/datasets/flwr_datasets/__init__.py b/datasets/flwr_datasets/__init__.py index d084780102ce..bd68fa43c606 100644 --- a/datasets/flwr_datasets/__init__.py +++ b/datasets/flwr_datasets/__init__.py @@ -23,11 +23,12 @@ __all__ = [ "FederatedDataset", - "partitioner", "metrics", - "visualization", + "partitioner", "preprocessor", "utils", + "visualization", ] + __version__ = _package_version diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 241f320714e2..3a85c195707a 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -28,15 +28,15 @@ from .semantic_partitioner import SemanticPartitioner __all__ = [ + "DirichletPartitioner", + "ExponentialPartitioner", "IidPartitioner", - "Partitioner", + "InnerDirichletPartitioner", + "LinearPartitioner", "NaturalIdPartitioner", - "DirichletPartitioner", + "Partitioner", + "ShardPartitioner", "SizePartitioner", - "LinearPartitioner", - "InnerDirichletPartitioner", "SquarePartitioner", - "ShardPartitioner", - "ExponentialPartitioner", "SemanticPartitioner" ] diff --git a/datasets/flwr_datasets/preprocessor/__init__.py b/datasets/flwr_datasets/preprocessor/__init__.py index bab5d82a2035..67b2aaadc3d2 100644 --- a/datasets/flwr_datasets/preprocessor/__init__.py +++ b/datasets/flwr_datasets/preprocessor/__init__.py @@ -20,7 +20,7 @@ from .preprocessor import Preprocessor __all__ = [ + "Divider", "Merger", "Preprocessor", - "Divider", ] diff --git a/datasets/flwr_datasets/visualization/__init__.py b/datasets/flwr_datasets/visualization/__init__.py index 801a38dcafc6..b55e406c71db 100644 --- a/datasets/flwr_datasets/visualization/__init__.py +++ b/datasets/flwr_datasets/visualization/__init__.py @@ -19,6 +19,6 @@ from .label_distribution import plot_label_distributions __all__ = [ - "plot_label_distributions", "plot_comparison_label_distribution", + "plot_label_distributions", ] diff --git a/datasets/flwr_datasets/visualization/bar_plot.py b/datasets/flwr_datasets/visualization/bar_plot.py index 6326b24a9695..339ff0967906 100644 --- a/datasets/flwr_datasets/visualization/bar_plot.py +++ b/datasets/flwr_datasets/visualization/bar_plot.py @@ -38,7 +38,6 @@ def _plot_bar( plot_kwargs: Optional[Dict[str, Any]], legend_kwargs: Optional[Dict[str, Any]], ) -> Axes: - if axis is None: if figsize is None: figsize = _initialize_figsize( diff --git a/datasets/flwr_datasets/visualization/heatmap_plot.py b/datasets/flwr_datasets/visualization/heatmap_plot.py index 2e593a79368e..3c87de7693ae 100644 --- a/datasets/flwr_datasets/visualization/heatmap_plot.py +++ b/datasets/flwr_datasets/visualization/heatmap_plot.py @@ -39,7 +39,6 @@ def _plot_heatmap( plot_kwargs: Optional[Dict[str, Any]], legend_kwargs: Optional[Dict[str, Any]], ) -> Axes: - if axis is None: if figsize is None: figsize = _initialize_figsize( @@ -92,7 +91,6 @@ def _initialize_figsize( num_partitions: int, num_labels: int, ) -> Tuple[float, float]: - figsize = (0.0, 0.0) if partition_id_axis == "x": figsize = (3 * np.sqrt(num_partitions), np.sqrt(num_labels)) diff --git a/dev/format.sh b/dev/format.sh index b9e3b00dffe1..71edf9c6065a 100755 --- a/dev/format.sh +++ b/dev/format.sh @@ -3,6 +3,7 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../ # Python +python -m flwr_tool.check_copyright src/py/flwr python -m flwr_tool.init_py_fix src/py/flwr python -m isort --skip src/py/flwr/proto src/py python -m black -q --exclude src/py/flwr/proto src/py diff --git a/dev/test.sh b/dev/test.sh index 7cabf35abf41..5b827380bc50 100755 --- a/dev/test.sh +++ b/dev/test.sh @@ -58,6 +58,10 @@ echo "- All Markdown checks passed" echo "- Start license checks" +echo "- copyright: start" +python -m flwr_tool.check_copyright src/py/flwr +echo "- copyright: done" + echo "- licensecheck: start" python -m licensecheck -u poetry --fail-licenses gpl --zero echo "- licensecheck: done" diff --git a/doc/source/contributor-how-to-build-docker-images.rst b/doc/source/contributor-how-to-build-docker-images.rst index 2efc739f54f0..4c178f439a07 100644 --- a/doc/source/contributor-how-to-build-docker-images.rst +++ b/doc/source/contributor-how-to-build-docker-images.rst @@ -65,7 +65,7 @@ Building the base image * - ``FLWR_VERSION`` - Version of Flower to be installed. - Yes - - ``1.8.0`` + - ``1.9.0`` * - ``FLWR_PACKAGE`` - The Flower package to be installed. - No @@ -73,14 +73,14 @@ Building the base image The following example creates a base Ubuntu/Alpine image with Python 3.11.0, pip 23.0.1, -setuptools 69.0.2 and Flower 1.8.0: +setuptools 69.0.2 and Flower 1.9.0: .. code-block:: bash $ cd src/docker/base/ $ docker build \ --build-arg PYTHON_VERSION=3.11.0 \ - --build-arg FLWR_VERSION=1.8.0 \ + --build-arg FLWR_VERSION=1.9.0 \ --build-arg PIP_VERSION=23.0.1 \ --build-arg SETUPTOOLS_VERSION=69.0.2 \ -t flwr_base:0.1.0 . @@ -106,7 +106,7 @@ Building the SuperLink/SuperNode or ServerApp image * - ``BASE_IMAGE`` - The Tag of the Flower base image. - Yes - - ``1.8.0-py3.10-ubuntu22.04`` + - ``1.9.0-py3.10-ubuntu22.04`` The following example creates a SuperLink/SuperNode or ServerApp image with the official Flower base image: diff --git a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst index 43f9739987ac..d7d647996a3d 100644 --- a/doc/source/contributor-tutorial-get-started-as-a-contributor.rst +++ b/doc/source/contributor-tutorial-get-started-as-a-contributor.rst @@ -17,8 +17,8 @@ supports `PEP 517 `_. Developer Machine Setup ----------------------- -Preliminarities -~~~~~~~~~~~~~~~ +Preliminaries +~~~~~~~~~~~~~ Some system-wide dependencies are needed. For macOS diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index 725a7468090c..b00e2ae803ab 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -20,7 +20,7 @@ Stable releases are available on `PyPI `_:: For simulations that use the Virtual Client Engine, ``flwr`` should be installed with the ``simulation`` extra:: - python -m pip install flwr[simulation] + python -m pip install "flwr[simulation]" Using conda (or mamba) diff --git a/doc/source/how-to-run-flower-using-docker.rst b/doc/source/how-to-run-flower-using-docker.rst index 54079968d417..7d9ec883960a 100644 --- a/doc/source/how-to-run-flower-using-docker.rst +++ b/doc/source/how-to-run-flower-using-docker.rst @@ -38,10 +38,10 @@ If you're looking to try out Flower, you can use the following command: .. code-block:: bash - $ docker run --rm -p 9091:9091 -p 9092:9092 flwr/superlink:1.8.0 --insecure + $ docker run --rm -p 9091:9091 -p 9092:9092 flwr/superlink:1.9.0 --insecure -The command pulls the Docker image with the tag ``1.8.0`` from Docker Hub. The tag specifies -the Flower version. In this case, Flower 1.8.0. The ``--rm`` flag tells Docker to remove the +The command pulls the Docker image with the tag ``1.9.0`` from Docker Hub. The tag specifies +the Flower version. In this case, Flower 1.9.0. The ``--rm`` flag tells Docker to remove the container after it exits. .. note:: @@ -66,7 +66,7 @@ You can use ``--help`` to view all available flags that the SuperLink supports: .. code-block:: bash - $ docker run --rm flwr/superlink:1.8.0 --help + $ docker run --rm flwr/superlink:1.9.0 --help Mounting a volume to store the state on the host system ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -88,7 +88,7 @@ container. Furthermore, we use the flag ``--database`` to specify the name of th $ mkdir state $ sudo chown -R 49999:49999 state $ docker run --rm \ - -p 9091:9091 -p 9092:9092 --volume ./state/:/app/state flwr/superlink:1.8.0 \ + -p 9091:9091 -p 9092:9092 --volume ./state/:/app/state flwr/superlink:1.9.0 \ --insecure \ --database state.db @@ -118,7 +118,7 @@ with the ``--ssl-ca-certfile``, ``--ssl-certfile`` and ``--ssl-keyfile`` flag. $ docker run --rm \ -p 9091:9091 -p 9092:9092 \ - --volume ./certificates/:/app/certificates/:ro flwr/superlink:nightly \ + --volume ./certificates/:/app/certificates/:ro flwr/superlink:1.9.0 \ --ssl-ca-certfile certificates/ca.crt \ --ssl-certfile certificates/server.pem \ --ssl-keyfile certificates/server.key @@ -136,14 +136,6 @@ Flower SuperNode The SuperNode Docker image comes with a pre-installed version of Flower and serves as a base for building your own SuperNode image. -.. important:: - - The SuperNode Docker image currently works only with the 1.9.0-nightly release. A stable version - will be available when Flower 1.9.0 (stable) gets released (ETA: May). A SuperNode nightly image - must be paired with the corresponding SuperLink and ServerApp nightly images released on the same - day. To ensure the versions are in sync, using the concrete tag, e.g., ``1.9.0.dev20240501`` - instead of ``nightly`` is recommended. - We will use the ``quickstart-pytorch`` example, which you can find in the Flower repository, to illustrate how you can dockerize your ClientApp. @@ -204,7 +196,7 @@ The ``Dockerfile.supernode`` contains the instructions that assemble the SuperNo .. code-block:: dockerfile - FROM flwr/supernode:nightly + FROM flwr/supernode:1.9.0 WORKDIR /app @@ -214,7 +206,7 @@ The ``Dockerfile.supernode`` contains the instructions that assemble the SuperNo COPY client.py ./ ENTRYPOINT ["flower-client-app", "client:app"] -In the first two lines, we instruct Docker to use the SuperNode image tagged ``nightly`` as a base +In the first two lines, we instruct Docker to use the SuperNode image tagged ``1.9.0`` as a base image and set our working directory to ``/app``. The following instructions will now be executed in the ``/app`` directory. Next, we install the ClientApp dependencies by copying the ``requirements.txt`` file into the image and run ``pip install``. In the last two lines, @@ -275,7 +267,7 @@ To see all available flags that the SuperNode supports, run: .. code-block:: bash - $ docker run --rm flwr/supernode:nightly --help + $ docker run --rm flwr/supernode:1.9.0 --help Enabling SSL for secure connections ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -325,14 +317,14 @@ The ``Dockerfile.serverapp`` contains the instructions that assemble the ServerA .. code-block:: dockerfile - FROM flwr/serverapp:1.8.0 + FROM flwr/serverapp:1.9.0 WORKDIR /app COPY server.py ./ ENTRYPOINT ["flower-server-app", "server:app"] -In the first two lines, we instruct Docker to use the ServerApp image tagged ``1.8.0`` as a base +In the first two lines, we instruct Docker to use the ServerApp image tagged ``1.9.0`` as a base image and set our working directory to ``/app``. The following instructions will now be executed in the ``/app`` directory. In the last two lines, we copy the ``server.py`` module into the image and set the entry point to ``flower-server-app`` with the argument ``server:app``. @@ -391,7 +383,7 @@ To see all available flags that the ServerApp supports, run: .. code-block:: bash - $ docker run --rm flwr/serverapp:1.8.0 --help + $ docker run --rm flwr/serverapp:1.9.0 --help Enabling SSL for secure connections ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -425,7 +417,7 @@ Run the Docker image with the ``-u`` flag and specify ``root`` as the username: .. code-block:: bash - $ docker run --rm -u root flwr/superlink:1.8.0 + $ docker run --rm -u root flwr/superlink:1.9.0 This command will run the Docker container with root user privileges. @@ -436,7 +428,7 @@ missing system dependencies, you can use the ``USER root`` directive within your .. code-block:: dockerfile - FROM flwr/supernode:1.8.0 + FROM flwr/supernode:1.9.0 # Switch to root user USER root @@ -456,6 +448,13 @@ Using a different Flower version If you want to use a different version of Flower, for example Flower nightly, you can do so by changing the tag. All available versions are on `Docker Hub `__. +.. important:: + + When using Flower nightly, the SuperLink nightly image must be paired with the corresponding + SuperNode and ServerApp nightly images released on the same day. To ensure the versions are + in sync, using the concrete tag, e.g., ``1.10.0.dev20240610`` instead of ``nightly`` is + recommended. + Pinning a Docker image to a specific version ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -464,19 +463,19 @@ updates of system dependencies that should not change the functionality of Flowe want to ensure that you always use the same image, you can specify the hash of the image instead of the tag. -The following command returns the current image hash referenced by the ``superlink:1.8.0`` tag: +The following command returns the current image hash referenced by the ``superlink:1.9.0`` tag: .. code-block:: bash - $ docker inspect --format='{{index .RepoDigests 0}}' flwr/superlink:1.8.0 - flwr/superlink@sha256:1b855d1fa4e344e4d95db99793f2bb35d8c63f6a1decdd736863bfe4bb0fe46c + $ docker inspect --format='{{index .RepoDigests 0}}' flwr/superlink:1.9.0 + flwr/superlink@sha256:985c24b2b337ab7f15a554fde9d860cede95079bcaa244fda8f12c0805e34c7d Next, we can pin the hash when running a new SuperLink container: .. code-block:: bash $ docker run \ - --rm flwr/superlink@sha256:1b855d1fa4e344e4d95db99793f2bb35d8c63f6a1decdd736863bfe4bb0fe46c \ + --rm flwr/superlink@sha256:985c24b2b337ab7f15a554fde9d860cede95079bcaa244fda8f12c0805e34c7d \ --insecure Setting environment variables @@ -487,4 +486,4 @@ To set a variable inside a Docker container, you can use the ``-e = .. code-block:: bash $ docker run -e FLWR_TELEMETRY_ENABLED=0 \ - --rm flwr/superlink:1.8.0 --insecure + --rm flwr/superlink:1.9.0 --insecure diff --git a/doc/source/how-to-use-differential-privacy.rst b/doc/source/how-to-use-differential-privacy.rst index c8901bd906cc..5d4fa3dca1a4 100644 --- a/doc/source/how-to-use-differential-privacy.rst +++ b/doc/source/how-to-use-differential-privacy.rst @@ -9,7 +9,7 @@ This guide explains how you can utilize differential privacy in the Flower frame Central Differential Privacy ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This approach consists of two seprate phases: clipping of the updates and adding noise to the aggregated model. +This approach consists of two separate phases: clipping of the updates and adding noise to the aggregated model. For the clipping phase, Flower framework has made it possible to decide whether to perform clipping on the server side or the client side. - **Server-side Clipping**: This approach has the advantage of the server enforcing uniform clipping across all clients' updates and reducing the communication overhead for clipping values. However, it also has the disadvantage of increasing the computational load on the server due to the need to perform the clipping operation for all clients. diff --git a/doc/source/how-to-use-strategies.rst b/doc/source/how-to-use-strategies.rst index d0e2cd63a091..8ac120124951 100644 --- a/doc/source/how-to-use-strategies.rst +++ b/doc/source/how-to-use-strategies.rst @@ -72,7 +72,7 @@ It must return a dictionary of arbitrary configuration values :code:`client.fit ) fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3), strategy=strategy) -The :code:`on_fit_config_fn` can be used to pass arbitrary configuration values from server to client, and poetentially change these values each round, for example, to adjust the learning rate. +The :code:`on_fit_config_fn` can be used to pass arbitrary configuration values from server to client, and potentially change these values each round, for example, to adjust the learning rate. The client will receive the dictionary returned by the :code:`on_fit_config_fn` in its own :code:`client.fit()` function. Similar to :code:`on_fit_config_fn`, there is also :code:`on_evaluate_config_fn` to customize the configuration sent to :code:`client.evaluate()` diff --git a/examples/fl-tabular/README.md b/examples/fl-tabular/README.md new file mode 100644 index 000000000000..58afd1080b70 --- /dev/null +++ b/examples/fl-tabular/README.md @@ -0,0 +1,39 @@ +# Flower Example on Adult Census Income Tabular Dataset + +This code exemplifies a federated learning setup using the Flower framework on the ["Adult Census Income"](https://huggingface.co/datasets/scikit-learn/adult-census-income) tabular dataset. The "Adult Census Income" dataset contains demographic information such as age, education, occupation, etc., with the target attribute being income level (\<=50K or >50K). The dataset is partitioned into subsets, simulating a federated environment with 5 clients, each holding a distinct portion of the data. Categorical variables are one-hot encoded, and the data is split into training and testing sets. Federated learning is conducted using the FedAvg strategy for 5 rounds. + +This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the dataset. + +## Environments Setup + +Start by cloning the example. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/fl-tabular . && rm -rf flower && cd fl-tabular +``` + +This will create a new directory called `fl-tabular` containing the following files: + +```shell +-- pyproject.toml +-- client.py +-- server.py +-- task.py +-- README.md +``` + +### Installing dependencies + +Project dependencies are defined in `pyproject.toml`. Install them with: + +```shell +pip install . +``` + +## Running Code + +### Federated Using Flower Simulation + +```bash +flower-simulation --server-app server:app --client-app client:app --num-supernodes 5 +``` diff --git a/examples/fl-tabular/client.py b/examples/fl-tabular/client.py new file mode 100644 index 000000000000..228183f4edc4 --- /dev/null +++ b/examples/fl-tabular/client.py @@ -0,0 +1,38 @@ +from flwr.client import Client, ClientApp, NumPyClient +from flwr_datasets import FederatedDataset +from task import set_weights, get_weights, train, evaluate, IncomeClassifier, load_data + +NUMBER_OF_CLIENTS = 5 + + +class FlowerClient(NumPyClient): + def __init__(self, net, trainloader, testloader): + self.net = net + self.trainloader = trainloader + self.testloader = testloader + + def fit(self, parameters, config): + set_weights(self.net, parameters) + train(self.net, self.trainloader) + return get_weights(self.net), len(self.trainloader), {} + + def evaluate(self, parameters, config): + set_weights(self.net, parameters) + loss, accuracy = evaluate(self.net, self.testloader) + return loss, len(self.testloader), {"accuracy": accuracy} + + +def get_client_fn(dataset: FederatedDataset): + def client_fn(cid: str) -> Client: + train_loader, test_loader = load_data(partition_id=int(cid), fds=dataset) + net = IncomeClassifier(14) + return FlowerClient(net, train_loader, test_loader).to_client() + + return client_fn + + +fds = FederatedDataset( + dataset="scikit-learn/adult-census-income", + partitioners={"train": NUMBER_OF_CLIENTS}, +) +app = ClientApp(client_fn=get_client_fn(fds)) diff --git a/examples/fl-tabular/pyproject.toml b/examples/fl-tabular/pyproject.toml new file mode 100644 index 000000000000..21498f73a4f3 --- /dev/null +++ b/examples/fl-tabular/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fl-tabular" +version = "0.1.0" +description = "Adult Census Income Tabular Dataset and Federated Learning in Flower" +authors = [ + { name = "The Flower Authors", email = "hello@flower.ai" }, +] +dependencies = [ + "flwr[simulation]>=1.9.0,<2.0", + "flwr-datasets>=0.1.0,<1.0.0", + "torch==2.1.1", + "scikit-learn==1.5.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["."] diff --git a/examples/fl-tabular/server.py b/examples/fl-tabular/server.py new file mode 100644 index 000000000000..376726f832f7 --- /dev/null +++ b/examples/fl-tabular/server.py @@ -0,0 +1,24 @@ +from flwr.common import ndarrays_to_parameters +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from task import IncomeClassifier, get_weights + +net = IncomeClassifier(input_dim=14) +params = ndarrays_to_parameters(get_weights(net)) + + +def weighted_average(metrics): + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + return {"accuracy": sum(accuracies) / sum(examples)} + + +strategy = FedAvg( + initial_parameters=params, + evaluate_metrics_aggregation_fn=weighted_average, +) +app = ServerApp( + strategy=strategy, + config=ServerConfig(num_rounds=5), +) diff --git a/examples/fl-tabular/task.py b/examples/fl-tabular/task.py new file mode 100644 index 000000000000..b07365c733d6 --- /dev/null +++ b/examples/fl-tabular/task.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import TensorDataset, DataLoader +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler, OrdinalEncoder +from sklearn.compose import ColumnTransformer +from sklearn.pipeline import Pipeline +from collections import OrderedDict +from flwr_datasets import FederatedDataset + + +def load_data(partition_id: int, fds: FederatedDataset): + dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:] + + dataset.dropna(inplace=True) + + categorical_cols = dataset.select_dtypes(include=["object"]).columns + ordinal_encoder = OrdinalEncoder() + dataset[categorical_cols] = ordinal_encoder.fit_transform(dataset[categorical_cols]) + + X = dataset.drop("income", axis=1) + y = dataset["income"] + + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + + numeric_features = X.select_dtypes(include=["float64", "int64"]).columns + numeric_transformer = Pipeline(steps=[("scaler", StandardScaler())]) + + preprocessor = ColumnTransformer( + transformers=[("num", numeric_transformer, numeric_features)] + ) + + X_train = preprocessor.fit_transform(X_train) + X_test = preprocessor.transform(X_test) + + X_train_tensor = torch.tensor(X_train, dtype=torch.float32) + X_test_tensor = torch.tensor(X_test, dtype=torch.float32) + y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32).view(-1, 1) + y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32).view(-1, 1) + + train_dataset = TensorDataset(X_train_tensor, y_train_tensor) + test_dataset = TensorDataset(X_test_tensor, y_test_tensor) + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) + + return train_loader, test_loader + + +class IncomeClassifier(nn.Module): + def __init__(self, input_dim: int): + super(IncomeClassifier, self).__init__() + self.layer1 = nn.Linear(input_dim, 128) + self.layer2 = nn.Linear(128, 64) + self.output = nn.Linear(64, 1) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = self.relu(self.layer1(x)) + x = self.relu(self.layer2(x)) + x = self.sigmoid(self.output(x)) + return x + + +def train(model, train_loader, num_epochs=1): + criterion = nn.BCELoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + model.train() + for epoch in range(num_epochs): + for X_batch, y_batch in train_loader: + optimizer.zero_grad() + outputs = model(X_batch) + loss = criterion(outputs, y_batch) + loss.backward() + optimizer.step() + + +def evaluate(model, test_loader): + model.eval() + criterion = nn.BCELoss() + loss = 0.0 + correct = 0 + total = 0 + with torch.no_grad(): + for X_batch, y_batch in test_loader: + outputs = model(X_batch) + batch_loss = criterion(outputs, y_batch) + loss += batch_loss.item() + predicted = (outputs > 0.5).float() + total += y_batch.size(0) + correct += (predicted == y_batch).sum().item() + accuracy = correct / total + loss = loss / len(test_loader) + return loss, accuracy + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + +def get_weights(net): + ndarrays = [val.cpu().numpy() for _, val in net.state_dict().items()] + return ndarrays diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 9bbc016de1a8..94da99dce36e 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -190,7 +190,7 @@ def new( ) print( typer.style( - f" cd {project_name}\n" + " pip install -e .\n flwr run\n", + f" cd {package_name}\n" + " pip install -e .\n flwr run\n", fg=typer.colors.BRIGHT_CYAN, bold=True, ) diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 56d6308a0fe2..74ade03f973a 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 82539834eaad..2e810f6560f2 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/client_test.py b/src/py/flwr/client/client_test.py index 373c676e5edc..343b6cf093b2 100644 --- a/src/py/flwr/client/client_test.py +++ b/src/py/flwr/client/client_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index ab31a289d29b..c592d10936d5 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/grpc_rere_client/__init__.py b/src/py/flwr/client/grpc_rere_client/__init__.py index 93903e725776..e7c9408c0047 100644 --- a/src/py/flwr/client/grpc_rere_client/__init__.py +++ b/src/py/flwr/client/grpc_rere_client/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index cc35ffef46db..9607dac5679e 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index b1c268d51d55..34dc0e417383 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/message_handler/__init__.py b/src/py/flwr/client/message_handler/__init__.py index 653563963de5..a345b4af3ef2 100644 --- a/src/py/flwr/client/message_handler/__init__.py +++ b/src/py/flwr/client/message_handler/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e5acbe0cc9d0..68326852970f 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 8a2db1804e4a..40907942513d 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/mod/__init__.py b/src/py/flwr/client/mod/__init__.py index 0b4cf6488421..35d1fa81805c 100644 --- a/src/py/flwr/client/mod/__init__.py +++ b/src/py/flwr/client/mod/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/mod/secure_aggregation/__init__.py b/src/py/flwr/client/mod/secure_aggregation/__init__.py index 8892d8c03935..a64bc89e62c9 100644 --- a/src/py/flwr/client/mod/secure_aggregation/__init__.py +++ b/src/py/flwr/client/mod/secure_aggregation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/mod/utils.py b/src/py/flwr/client/mod/utils.py index 4c3c32944f01..c8fb21379783 100644 --- a/src/py/flwr/client/mod/utils.py +++ b/src/py/flwr/client/mod/utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index 4676a2c02c4b..035e41639b10 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/numpy_client_test.py b/src/py/flwr/client/numpy_client_test.py index 526098798e45..06a0deafe2c9 100644 --- a/src/py/flwr/client/numpy_client_test.py +++ b/src/py/flwr/client/numpy_client_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/rest_client/__init__.py b/src/py/flwr/client/rest_client/__init__.py index c3485483ad35..a24d822a6d75 100644 --- a/src/py/flwr/client/rest_client/__init__.py +++ b/src/py/flwr/client/rest_client/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 5f5e153f9d8d..db5bd7eb6770 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 742281f5c011..c9a16edeaf15 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -267,7 +267,7 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser: "--flwr-dir", default=None, help="""The path containing installed Flower Apps. - By default, this value isequal to: + By default, this value is equal to: - `$FLWR_HOME/` if `$FLWR_HOME` is defined - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined diff --git a/src/py/flwr/common/address.py b/src/py/flwr/common/address.py index 71b6d684597f..1c6481b80a74 100644 --- a/src/py/flwr/common/address.py +++ b/src/py/flwr/common/address.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/address_test.py b/src/py/flwr/common/address_test.py index 420b89871d69..d5901ed640b1 100644 --- a/src/py/flwr/common/address_test.py +++ b/src/py/flwr/common/address_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 193f000ca42e..ce29b3edb30e 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/date.py b/src/py/flwr/common/date.py index f47ad5470106..7f30f5e0591a 100644 --- a/src/py/flwr/common/date.py +++ b/src/py/flwr/common/date.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/dp.py b/src/py/flwr/common/dp.py index 83a72b8ce749..527805c8ef42 100644 --- a/src/py/flwr/common/dp.py +++ b/src/py/flwr/common/dp.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index ead0329ca79c..ec8fe823a7eb 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py index 19f8aeb1eb63..daee57896903 100644 --- a/src/py/flwr/common/message_test.py +++ b/src/py/flwr/common/message_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/__init__.py b/src/py/flwr/common/secure_aggregation/__init__.py index b4e0acc0c148..77e1ea3842d7 100644 --- a/src/py/flwr/common/secure_aggregation/__init__.py +++ b/src/py/flwr/common/secure_aggregation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/crypto/__init__.py b/src/py/flwr/common/secure_aggregation/crypto/__init__.py index 2cb34493f7d0..3788dbc0ca15 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/__init__.py +++ b/src/py/flwr/common/secure_aggregation/crypto/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/crypto/shamir.py b/src/py/flwr/common/secure_aggregation/crypto/shamir.py index e56e21b89371..688bfa2153ea 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/shamir.py +++ b/src/py/flwr/common/secure_aggregation/crypto/shamir.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 1d004a398ea8..59ca84d604b8 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py index e926a9531bea..207c15b61518 100644 --- a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py +++ b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/quantization.py b/src/py/flwr/common/secure_aggregation/quantization.py index 56c25e2bd59c..7946276b6a4f 100644 --- a/src/py/flwr/common/secure_aggregation/quantization.py +++ b/src/py/flwr/common/secure_aggregation/quantization.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py index 8a15908c13c5..545507eb44ed 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py index c373573477b9..cf6ac3bfb003 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index f9969426fc36..afb11b6956f2 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/common/version.py b/src/py/flwr/common/version.py index 6808c66606b1..ac13f70d8a88 100644 --- a/src/py/flwr/common/version.py +++ b/src/py/flwr/common/version.py @@ -1,3 +1,17 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== """Flower package version helper.""" import importlib.metadata as importlib_metadata diff --git a/src/py/flwr/server/client_proxy_test.py b/src/py/flwr/server/client_proxy_test.py index 685698558e3a..6ca37052a87d 100644 --- a/src/py/flwr/server/client_proxy_test.py +++ b/src/py/flwr/server/client_proxy_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/compat/app.py b/src/py/flwr/server/compat/app.py index 4bb23b846ab7..e978359fa828 100644 --- a/src/py/flwr/server/compat/app.py +++ b/src/py/flwr/server/compat/app.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py index 1cdf1efbffb9..baff27307b88 100644 --- a/src/py/flwr/server/compat/app_utils.py +++ b/src/py/flwr/server/compat/app_utils.py @@ -91,7 +91,7 @@ def _update_client_manager( node_id=node_id, driver=driver, anonymous=False, - run_id=driver.run_id, # type: ignore + run_id=driver.run.run_id, ) if client_manager.register(client_proxy): registered_nodes[node_id] = client_proxy diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index 150803786f98..7190786784ec 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index d9e3d3bc0824..31b917fa869b 100644 --- a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index b95cec95ab47..4f888323e586 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -19,11 +19,17 @@ from typing import Iterable, List, Optional from flwr.common import Message, RecordSet +from flwr.common.typing import Run class Driver(ABC): """Abstract base Driver class for the Driver API.""" + @property + @abstractmethod + def run(self) -> Run: + """Run information.""" + @abstractmethod def create_message( # pylint: disable=too-many-arguments self, diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index d339f1b232f9..e614df659e3f 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import time import warnings from logging import DEBUG, ERROR, WARNING -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, cast import grpc @@ -25,6 +25,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.common.typing import Run from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, CreateRunResponse, @@ -37,6 +38,7 @@ ) from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from .driver import Driver @@ -46,13 +48,24 @@ ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """ [Driver] Error: Not connected. -Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other -`GrpcDriverHelper` methods. +Call `connect()` on the `GrpcDriverStub` instance before calling any of the other +`GrpcDriverStub` methods. """ -class GrpcDriverHelper: - """`GrpcDriverHelper` provides access to the gRPC Driver API/service.""" +class GrpcDriverStub: + """`GrpcDriverStub` provides access to the gRPC Driver API/service. + + Parameters + ---------- + driver_service_address : Optional[str] + The IPv4 or IPv6 address of the Driver API server. + Defaults to `"[::]:9091"`. + root_certificates : Optional[bytes] (default: None) + The PEM-encoded root certificates as a byte string. + If provided, a secure connection using the certificates will be + established to an SSL-enabled Flower server. + """ def __init__( self, @@ -64,6 +77,10 @@ def __init__( self.channel: Optional[grpc.Channel] = None self.stub: Optional[DriverStub] = None + def is_connected(self) -> bool: + """Return True if connected to the Driver API server, otherwise False.""" + return self.channel is not None + def connect(self) -> None: """Connect to the Driver API.""" event(EventType.DRIVER_CONNECT) @@ -95,18 +112,29 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API res: CreateRunResponse = self.stub.CreateRun(request=req) return res + def get_run(self, req: GetRunRequest) -> GetRunResponse: + """Get run information.""" + # Check if channel is open + if self.stub is None: + log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) + raise ConnectionError("`GrpcDriverStub` instance not connected") + + # Call gRPC Driver API + res: GetRunResponse = self.stub.GetRun(request=req) + return res + def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: """Get client IDs.""" # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) @@ -117,7 +145,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) @@ -128,7 +156,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) @@ -140,56 +168,52 @@ class GrpcDriver(Driver): Parameters ---------- - driver_service_address : Optional[str] - The IPv4 or IPv6 address of the Driver API server. - Defaults to `"[::]:9091"`. - certificates : bytes (default: None) - Tuple containing root certificate, server certificate, and private key - to start a secure SSL-enabled server. The tuple is expected to have - three bytes elements in the following order: - - * CA certificate. - * server certificate. - * server private key. - fab_id : str (default: None) - The identifier of the FAB used in the run. - fab_version : str (default: None) - The version of the FAB used in the run. + run_id : int + The identifier of the run. + stub : Optional[GrpcDriverStub] (default: None) + The ``GrpcDriverStub`` instance used to communicate with the SuperLink. + If None, an instance connected to "[::]:9091" will be created. """ - def __init__( + def __init__( # pylint: disable=too-many-arguments self, - driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - root_certificates: Optional[bytes] = None, - fab_id: Optional[str] = None, - fab_version: Optional[str] = None, + run_id: int, + stub: Optional[GrpcDriverStub] = None, ) -> None: - self.addr = driver_service_address - self.root_certificates = root_certificates - self.driver_helper: Optional[GrpcDriverHelper] = None - self.run_id: Optional[int] = None - self.fab_id = fab_id if fab_id is not None else "" - self.fab_version = fab_version if fab_version is not None else "" + self._run_id = run_id + self._run: Optional[Run] = None + self.stub = stub if stub is not None else GrpcDriverStub() self.node = Node(node_id=0, anonymous=True) - def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: - # Check if the GrpcDriverHelper is initialized - if self.driver_helper is None or self.run_id is None: - # Connect and create run - self.driver_helper = GrpcDriverHelper( - driver_service_address=self.addr, - root_certificates=self.root_certificates, + @property + def run(self) -> Run: + """Run information.""" + self._get_stub_and_run_id() + return Run(**vars(cast(Run, self._run))) + + def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: + # Check if is initialized + if self._run is None: + # Connect + if not self.stub.is_connected(): + self.stub.connect() + # Get the run info + req = GetRunRequest(run_id=self._run_id) + res = self.stub.get_run(req) + if not res.HasField("run"): + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._run = Run( + run_id=res.run.run_id, + fab_id=res.run.fab_id, + fab_version=res.run.fab_version, ) - self.driver_helper.connect() - req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version) - res = self.driver_helper.create_run(req) - self.run_id = res.run_id - return self.driver_helper, self.run_id + + return self.stub, self._run.run_id def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( - message.metadata.run_id == self.run_id + message.metadata.run_id == cast(Run, self._run).run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" @@ -210,7 +234,7 @@ def create_message( # pylint: disable=too-many-arguments This method constructs a new `Message` with given content and metadata. The `run_id` and `src_node_id` will be set automatically. """ - _, run_id = self._get_grpc_driver_helper_and_run_id() + _, run_id = self._get_stub_and_run_id() if ttl: warnings.warn( "A custom TTL was set, but note that the SuperLink does not enforce " @@ -234,9 +258,9 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id() - # Call GrpcDriverHelper method - res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id)) + stub, run_id = self._get_stub_and_run_id() + # Call GrpcDriverStub method + res = stub.get_nodes(GetNodesRequest(run_id=run_id)) return [node.node_id for node in res.nodes] def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: @@ -245,7 +269,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: This method takes an iterable of messages and sends each message to the node specified in `dst_node_id`. """ - grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id() + stub, _ = self._get_stub_and_run_id() # Construct TaskIns task_ins_list: List[TaskIns] = [] for msg in messages: @@ -255,10 +279,8 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: taskins = message_to_taskins(msg) # Add to list task_ins_list.append(taskins) - # Call GrpcDriverHelper method - res = grpc_driver_helper.push_task_ins( - PushTaskInsRequest(task_ins_list=task_ins_list) - ) + # Call GrpcDriverStub method + res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) return list(res.task_ids) def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: @@ -267,9 +289,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: This method is used to collect messages from the SuperLink that correspond to a set of given message IDs. """ - grpc_driver, _ = self._get_grpc_driver_helper_and_run_id() + stub, _ = self._get_stub_and_run_id() # Pull TaskRes - res = grpc_driver.pull_task_res( + res = stub.pull_task_res( PullTaskResRequest(node=self.node, task_ids=message_ids) ) # Convert TaskRes to Message @@ -308,8 +330,8 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" - # Check if GrpcDriverHelper is initialized - if self.driver_helper is None: + # Check if `connect` was called before + if not self.stub.is_connected(): return # Disconnect - self.driver_helper.disconnect() + self.stub.disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index fbead0e3043d..72efc5f8b2c6 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ PullTaskResRequest, PushTaskInsRequest, ) +from flwr.proto.run_pb2 import Run # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from .grpc_driver import GrpcDriver @@ -36,58 +37,36 @@ class TestGrpcDriver(unittest.TestCase): """Tests for `GrpcDriver` class.""" def setUp(self) -> None: - """Initialize mock GrpcDriverHelper and Driver instance before each test.""" - mock_response = Mock() - mock_response.run_id = 61016 - self.mock_grpc_driver_helper = Mock() - self.mock_grpc_driver_helper.create_run.return_value = mock_response - self.patcher = patch( - "flwr.server.driver.grpc_driver.GrpcDriverHelper", - return_value=self.mock_grpc_driver_helper, + """Initialize mock GrpcDriverStub and Driver instance before each test.""" + mock_response = Mock( + run=Run(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0") ) - self.patcher.start() - self.driver = GrpcDriver() - - def tearDown(self) -> None: - """Cleanup after each test.""" - self.patcher.stop() - - def test_check_and_init_grpc_driver_already_initialized(self) -> None: - """Test that GrpcDriverHelper doesn't initialize if run is created.""" - # Prepare - self.driver.driver_helper = self.mock_grpc_driver_helper - self.driver.run_id = 61016 - - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() + self.mock_grpc_driver_stub = Mock() + self.mock_grpc_driver_stub.get_run.return_value = mock_response + self.mock_grpc_driver_stub.HasField.return_value = True + self.driver = GrpcDriver(run_id=61016, stub=self.mock_grpc_driver_stub) + def test_init_grpc_driver(self) -> None: + """Test GrpcDriverStub initialization.""" # Assert - self.mock_grpc_driver_helper.connect.assert_not_called() - - def test_check_and_init_grpc_driver_needs_initialization(self) -> None: - """Test GrpcDriverHelper initialization when run is not created.""" - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() - - # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() - self.assertEqual(self.driver.run_id, 61016) + self.assertEqual(self.driver.run.run_id, 61016) + self.assertEqual(self.driver.run.fab_id, "mock/mock") + self.assertEqual(self.driver.run.fab_version, "v1.0.0") + self.mock_grpc_driver_stub.get_run.assert_called_once() def test_get_nodes(self) -> None: """Test retrieval of nodes.""" # Prepare mock_response = Mock() mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] - self.mock_grpc_driver_helper.get_nodes.return_value = mock_response + self.mock_grpc_driver_stub.get_nodes.return_value = mock_response # Execute node_ids = self.driver.get_node_ids() - args, kwargs = self.mock_grpc_driver_helper.get_nodes.call_args + args, kwargs = self.mock_grpc_driver_stub.get_nodes.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) @@ -98,7 +77,7 @@ def test_push_messages_valid(self) -> None: """Test pushing valid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -106,10 +85,10 @@ def test_push_messages_valid(self) -> None: # Execute msg_ids = self.driver.push_messages(msgs) - args, kwargs = self.mock_grpc_driver_helper.push_task_ins.call_args + args, kwargs = self.mock_grpc_driver_stub.push_task_ins.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PushTaskInsRequest) @@ -121,7 +100,7 @@ def test_push_messages_invalid(self) -> None: """Test pushing invalid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -145,16 +124,16 @@ def test_pull_messages_with_given_message_ids(self) -> None: ), TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), ] - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msg_ids = ["id1", "id2", "id3"] # Execute msgs = self.driver.pull_messages(msg_ids) reply_tos = {msg.metadata.reply_to_message for msg in msgs} - args, kwargs = self.mock_grpc_driver_helper.pull_task_res.call_args + args, kwargs = self.mock_grpc_driver_stub.pull_task_res.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullTaskResRequest) @@ -165,14 +144,14 @@ def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" # Prepare mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response # The response message must include either `content` (i.e. a recordset) or # an `Error`. We choose the latter in this case error_proto = error_to_proto(Error(code=0)) mock_response = Mock( task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -187,9 +166,9 @@ def test_send_and_receive_messages_timeout(self) -> None: # Prepare sleep_fn = time.sleep mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response mock_response = Mock(task_res_list=[]) - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -204,19 +183,21 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" # Prepare - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() + self.mock_grpc_driver_stub.is_connected.return_value = True # Execute self.driver.close() # Assert - self.mock_grpc_driver_helper.disconnect.assert_called_once() + self.mock_grpc_driver_stub.disconnect.assert_called_once() def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" + # Prepare + self.mock_grpc_driver_stub.is_connected.return_value = False + # Execute self.driver.close() # Assert - self.mock_grpc_driver_helper.disconnect.assert_not_called() + self.mock_grpc_driver_stub.disconnect.assert_not_called() diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 8c71b1067293..53406796750f 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,11 +17,12 @@ import time import warnings -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, cast from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.server.superlink.state import StateFactory @@ -33,30 +34,27 @@ class InMemoryDriver(Driver): Parameters ---------- + run_id : int + The identifier of the run. state_factory : StateFactory A StateFactory embedding a state that this driver can interface with. - fab_id : str (default: None) - The identifier of the FAB used in the run. - fab_version : str (default: None) - The version of the FAB used in the run. """ def __init__( self, + run_id: int, state_factory: StateFactory, - fab_id: Optional[str] = None, - fab_version: Optional[str] = None, ) -> None: - self.run_id: Optional[int] = None - self.fab_id = fab_id if fab_id is not None else "" - self.fab_version = fab_version if fab_version is not None else "" - self.node = Node(node_id=0, anonymous=True) + self._run_id = run_id + self._run: Optional[Run] = None self.state = state_factory.state() + self.node = Node(node_id=0, anonymous=True) def _check_message(self, message: Message) -> None: + self._init_run() # Check if the message is valid if not ( - message.metadata.run_id == self.run_id + message.metadata.run_id == cast(Run, self._run).run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" @@ -64,16 +62,20 @@ def _check_message(self, message: Message) -> None: ): raise ValueError(f"Invalid message: {message}") - def _get_run_id(self) -> int: - """Return run_id. - - If unset, create a new run. - """ - if self.run_id is None: - self.run_id = self.state.create_run( - fab_id=self.fab_id, fab_version=self.fab_version - ) - return self.run_id + def _init_run(self) -> None: + """Initialize the run.""" + if self._run is not None: + return + run = self.state.get_run(self._run_id) + if run is None: + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._run = run + + @property + def run(self) -> Run: + """Run ID.""" + self._init_run() + return Run(**vars(cast(Run, self._run))) def create_message( # pylint: disable=too-many-arguments self, @@ -88,7 +90,7 @@ def create_message( # pylint: disable=too-many-arguments This method constructs a new `Message` with given content and metadata. The `run_id` and `src_node_id` will be set automatically. """ - run_id = self._get_run_id() + self._init_run() if ttl: warnings.warn( "A custom TTL was set, but note that the SuperLink does not enforce " @@ -99,7 +101,7 @@ def create_message( # pylint: disable=too-many-arguments ttl_ = DEFAULT_TTL if ttl is None else ttl metadata = Metadata( - run_id=run_id, + run_id=cast(Run, self._run).run_id, message_id="", # Will be set by the server src_node_id=self.node.node_id, dst_node_id=dst_node_id, @@ -112,8 +114,8 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - run_id = self._get_run_id() - return list(self.state.get_nodes(run_id)) + self._init_run() + return list(self.state.get_nodes(cast(Run, self._run).run_id)) def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """Push messages to specified node IDs. diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 95c2a0b277af..eff38f548826 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,8 +31,9 @@ message_to_taskres, recordset_to_proto, ) +from flwr.common.typing import Run from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory from .inmemory_driver import InMemoryDriver @@ -79,12 +80,24 @@ def setUp(self) -> None: """ # Create driver self.num_nodes = 42 - self.driver = InMemoryDriver(StateFactory("")) - self.driver.state = MagicMock() - self.driver.state.get_nodes.return_value = [ + self.state = MagicMock() + self.state.get_nodes.return_value = [ int.from_bytes(os.urandom(8), "little", signed=True) for _ in range(self.num_nodes) ] + self.state.get_run.return_value = Run( + run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" + ) + state_factory = MagicMock(state=lambda: self.state) + self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) + self.driver.state = self.state + + def test_get_run(self) -> None: + """Test the InMemoryDriver starting with run_id.""" + # Assert + self.assertEqual(self.driver.run.run_id, 61016) + self.assertEqual(self.driver.run.fab_id, "mock/mock") + self.assertEqual(self.driver.run.fab_version, "v1.0.0") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -104,7 +117,7 @@ def test_push_messages_valid(self) -> None: ] taskins_ids = [uuid4() for _ in range(num_messages)] - self.driver.state.store_task_ins.side_effect = taskins_ids # type: ignore + self.state.store_task_ins.side_effect = taskins_ids # Execute msg_ids = list(self.driver.push_messages(msgs)) @@ -141,7 +154,7 @@ def test_pull_messages_with_given_message_ids(self) -> None: task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0))) ), ] - self.driver.state.get_task_res.return_value = task_res_list # type: ignore + self.state.get_task_res.return_value = task_res_list # Execute pulled_msgs = list(self.driver.pull_messages(msg_ids)) @@ -167,8 +180,8 @@ def test_send_and_receive_messages_complete(self) -> None: task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0))) ), ] - self.driver.state.store_task_ins.side_effect = msg_ids # type: ignore - self.driver.state.get_task_res.return_value = task_res_list # type: ignore + self.state.store_task_ins.side_effect = msg_ids + self.state.get_task_res.return_value = task_res_list # Execute ret_msgs = list(self.driver.send_and_receive(msgs)) @@ -193,8 +206,8 @@ def test_send_and_receive_messages_timeout(self) -> None: task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0))) ), ] - self.driver.state.store_task_ins.side_effect = msg_ids # type: ignore - self.driver.state.get_task_res.return_value = task_res_list # type: ignore + self.state.store_task_ins.side_effect = msg_ids + self.state.get_task_res.return_value = task_res_list # Execute with patch("time.sleep", side_effect=lambda t: time.sleep(t * 0.01)): @@ -208,19 +221,23 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: """Test tasks are deleted in sqlite state once messages are pulled.""" # Prepare - self.driver = InMemoryDriver(StateFactory("")) + state = StateFactory("").state() + self.driver = InMemoryDriver( + state.create_run("", ""), MagicMock(state=lambda: state) + ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) + assert isinstance(state, SqliteState) # Check recorded - task_ins = self.driver.state.query("SELECT * FROM task_ins;") # type: ignore + task_ins = state.query("SELECT * FROM task_ins;") self.assertEqual(len(task_ins), len(list(msg_ids))) # Prepare: create replies reply_tos = get_replies(self.driver, msg_ids, node_id) # Query number of task_ins and task_res in State - task_res = self.driver.state.query("SELECT * FROM task_res;") # type: ignore - task_ins = self.driver.state.query("SELECT * FROM task_ins;") # type: ignore + task_res = state.query("SELECT * FROM task_res;") + task_ins = state.query("SELECT * FROM task_ins;") # Assert self.assertEqual(reply_tos, msg_ids) @@ -230,18 +247,19 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: """Test tasks are deleted in in-memory state once messages are pulled.""" # Prepare - self.driver = InMemoryDriver(StateFactory(":flwr-in-memory-state:")) + state_factory = StateFactory(":flwr-in-memory-state:") + state = state_factory.state() + self.driver = InMemoryDriver(state.create_run("", ""), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) + assert isinstance(state, InMemoryState) # Check recorded - self.assertEqual( - len(self.driver.state.task_ins_store), len(list(msg_ids)) # type: ignore - ) + self.assertEqual(len(state.task_ins_store), len(list(msg_ids))) # Prepare: create replies reply_tos = get_replies(self.driver, msg_ids, node_id) # Assert self.assertEqual(reply_tos, msg_ids) - self.assertEqual(len(self.driver.state.task_res_store), 0) # type: ignore - self.assertEqual(len(self.driver.state.task_ins_store), 0) # type: ignore + self.assertEqual(len(state.task_res_store), 0) + self.assertEqual(len(state.task_ins_store), 0) diff --git a/src/py/flwr/server/history_test.py b/src/py/flwr/server/history_test.py index adb9d697e409..b53357149623 100644 --- a/src/py/flwr/server/history_test.py +++ b/src/py/flwr/server/history_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index efd3f6846264..63ffc4a1caae 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -24,8 +24,10 @@ from flwr.common import Context, EventType, RecordSet, event from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app +from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 -from .driver import Driver, GrpcDriver +from .driver import Driver +from .driver.grpc_driver import GrpcDriver, GrpcDriverStub from .server_app import LoadServerAppError, ServerApp ADDRESS_DRIVER_API = "0.0.0.0:9091" @@ -149,13 +151,16 @@ def run_server_app() -> None: server_app_dir = args.dir server_app_attr = getattr(args, "server-app") - # Initialize GrpcDriver - driver = GrpcDriver( - driver_service_address=args.superlink, - root_certificates=root_certificates, - fab_id=args.fab_id, - fab_version=args.fab_version, + # Create run + stub = GrpcDriverStub( + driver_service_address=args.superlink, root_certificates=root_certificates ) + stub.connect() + req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version) + res = stub.create_run(req) + + # Initialize GrpcDriver + driver = GrpcDriver(run_id=res.run_id, stub=stub) # Run the ServerApp with the Driver run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index 38c0d6240d90..0751a0cb2bc5 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/bulyan.py b/src/py/flwr/server/strategy/bulyan.py index 1e4f97530ab7..a81406c255ad 100644 --- a/src/py/flwr/server/strategy/bulyan.py +++ b/src/py/flwr/server/strategy/bulyan.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/bulyan_test.py b/src/py/flwr/server/strategy/bulyan_test.py index 299ed49066fb..93a9ebda3783 100644 --- a/src/py/flwr/server/strategy/bulyan_test.py +++ b/src/py/flwr/server/strategy/bulyan_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index a908679ed668..423ddddeb379 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index c54379fc7087..d122f0688922 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedadagrad.py b/src/py/flwr/server/strategy/fedadagrad.py index 4a8f52d98e18..f13c5358da25 100644 --- a/src/py/flwr/server/strategy/fedadagrad.py +++ b/src/py/flwr/server/strategy/fedadagrad.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedadagrad_test.py b/src/py/flwr/server/strategy/fedadagrad_test.py index 0c966442ecaf..b43a4c75d123 100644 --- a/src/py/flwr/server/strategy/fedadagrad_test.py +++ b/src/py/flwr/server/strategy/fedadagrad_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedadam.py b/src/py/flwr/server/strategy/fedadam.py index 8a47cf0dd8ac..dc90e90c7568 100644 --- a/src/py/flwr/server/strategy/fedadam.py +++ b/src/py/flwr/server/strategy/fedadam.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py index 6678b7ced114..2f49cf8784c9 100644 --- a/src/py/flwr/server/strategy/fedavg_android.py +++ b/src/py/flwr/server/strategy/fedavg_android.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedavgm.py b/src/py/flwr/server/strategy/fedavgm.py index fb9261abe89d..ab3d37249db6 100644 --- a/src/py/flwr/server/strategy/fedavgm.py +++ b/src/py/flwr/server/strategy/fedavgm.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedavgm_test.py b/src/py/flwr/server/strategy/fedavgm_test.py index a0e942171627..39da5f4b82c4 100644 --- a/src/py/flwr/server/strategy/fedavgm_test.py +++ b/src/py/flwr/server/strategy/fedavgm_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py index 17e979d92beb..e7cba5324fa8 100644 --- a/src/py/flwr/server/strategy/fedmedian.py +++ b/src/py/flwr/server/strategy/fedmedian.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedmedian_test.py b/src/py/flwr/server/strategy/fedmedian_test.py index 57cf08d8c01d..3960ad70b145 100644 --- a/src/py/flwr/server/strategy/fedmedian_test.py +++ b/src/py/flwr/server/strategy/fedmedian_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedopt.py b/src/py/flwr/server/strategy/fedopt.py index be5f260d96fa..c581d4797123 100644 --- a/src/py/flwr/server/strategy/fedopt.py +++ b/src/py/flwr/server/strategy/fedopt.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedprox.py b/src/py/flwr/server/strategy/fedprox.py index d20f578b193d..f15271e06060 100644 --- a/src/py/flwr/server/strategy/fedprox.py +++ b/src/py/flwr/server/strategy/fedprox.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index a8e8adddafbb..a74ee81976a6 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py index 2605daab29f4..75025a89728b 100644 --- a/src/py/flwr/server/strategy/fedxgb_cyclic.py +++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedxgb_nn_avg.py b/src/py/flwr/server/strategy/fedxgb_nn_avg.py index 8dedc925f350..4562663287ae 100644 --- a/src/py/flwr/server/strategy/fedxgb_nn_avg.py +++ b/src/py/flwr/server/strategy/fedxgb_nn_avg.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/fedyogi.py b/src/py/flwr/server/strategy/fedyogi.py index 7c77aab7ae73..c7b2ebb51667 100644 --- a/src/py/flwr/server/strategy/fedyogi.py +++ b/src/py/flwr/server/strategy/fedyogi.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/krum.py b/src/py/flwr/server/strategy/krum.py index 16eb5212940e..074d018c35a3 100644 --- a/src/py/flwr/server/strategy/krum.py +++ b/src/py/flwr/server/strategy/krum.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/krum_test.py b/src/py/flwr/server/strategy/krum_test.py index 653dc9a8475d..b34982325b39 100644 --- a/src/py/flwr/server/strategy/krum_test.py +++ b/src/py/flwr/server/strategy/krum_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/multikrum_test.py b/src/py/flwr/server/strategy/multikrum_test.py index f874dc2f9800..7a1a4c3ecf38 100644 --- a/src/py/flwr/server/strategy/multikrum_test.py +++ b/src/py/flwr/server/strategy/multikrum_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py index 758e8e608e9f..26a397d4cf8c 100644 --- a/src/py/flwr/server/strategy/qfedavg.py +++ b/src/py/flwr/server/strategy/qfedavg.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/driver/__init__.py b/src/py/flwr/server/superlink/driver/__init__.py index 2bfe63e6065f..58fbc479478f 100644 --- a/src/py/flwr/server/superlink/driver/__init__.py +++ b/src/py/flwr/server/superlink/driver/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/driver/driver_grpc.py b/src/py/flwr/server/superlink/driver/driver_grpc.py index f74000bc59c4..782935481945 100644 --- a/src/py/flwr/server/superlink/driver/driver_grpc.py +++ b/src/py/flwr/server/superlink/driver/driver_grpc.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index e808616af778..03128f02158e 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,7 +35,11 @@ PushTaskInsResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 +from flwr.proto.run_pb2 import ( # pylint: disable=E0611 + GetRunRequest, + GetRunResponse, + Run, +) from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -134,7 +138,15 @@ def GetRun( self, request: GetRunRequest, context: grpc.ServicerContext ) -> GetRunResponse: """Get run information.""" - raise NotImplementedError + log(DEBUG, "DriverServicer.GetRun") + + # Init state + state: State = self.state_factory.state() + + # Retrieve run information + run = state.get_run(request.run_id) + run_proto = None if run is None else Run(**vars(run)) + return GetRunResponse(run=run_proto) def _raise_if(validation_error: bool, detail: str) -> None: diff --git a/src/py/flwr/server/superlink/driver/driver_servicer_test.py b/src/py/flwr/server/superlink/driver/driver_servicer_test.py index 99f7cc007a89..394d6be7ee6a 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer_test.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/__init__.py b/src/py/flwr/server/superlink/fleet/__init__.py index d3c3ef90163d..c236ed06ae1c 100644 --- a/src/py/flwr/server/superlink/fleet/__init__.py +++ b/src/py/flwr/server/superlink/fleet/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py index bae8bc431edd..6b2c2bf3ffec 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py index 6f94ea844e38..79f1a8f9902b 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py index bd93554a6a32..03e8555f8ecf 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py index d5b4a915c609..5fe0396696ab 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py index f7c236acd7a1..f9b6b97030f0 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py index ac62ad014950..03497743becd 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py index e7077dfd39ae..6d3eb4f67e30 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index ae685fda91a7..1b4286c87b92 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py index 8afa37515950..7ff730b17afa 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py index 61ab71d91400..03c8ded2423a 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index 13e024eb31e4..89342a46eb48 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/message_handler/__init__.py b/src/py/flwr/server/superlink/fleet/message_handler/__init__.py index 18b0f11fa6c5..3db0ef5d1611 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/__init__.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index dceb18cab453..b70cd54035fe 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py index c135f6fb7b61..ec521b328eb8 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/__init__.py b/src/py/flwr/server/superlink/fleet/rest_rere/__init__.py index a926f9ca0bfc..f24db2a2e12f 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/__init__.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index c7ff496d39bf..1ed67e5eb0aa 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/state/__init__.py b/src/py/flwr/server/superlink/state/__init__.py index 7f260d733bbe..9d3bd220403b 100644 --- a/src/py/flwr/server/superlink/state/__init__.py +++ b/src/py/flwr/server/superlink/state/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index e03260355db9..da9c754c3115 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index b9672757b0e6..4df9470ded62 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/state/sqlite_state_test.py b/src/py/flwr/server/superlink/state/sqlite_state_test.py index 20927df1cf12..10e12da96bd5 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state_test.py +++ b/src/py/flwr/server/superlink/state/sqlite_state_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index d1fc9465c9f2..65e2c63cab69 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/superlink/state/state_factory.py b/src/py/flwr/server/superlink/state/state_factory.py index 62a00d910828..96c8d445c16e 100644 --- a/src/py/flwr/server/superlink/state/state_factory.py +++ b/src/py/flwr/server/superlink/state/state_factory.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,16 @@ class StateFactory: - """Factory class that creates State instances.""" + """Factory class that creates State instances. + + Parameters + ---------- + database : str + A string representing the path to the database file that will be opened. + Note that passing ':memory:' will open a connection to a database that is + in RAM, instead of on disk. For more information on special in-memory + databases, please refer to https://sqlite.org/inmemorydb.html. + """ def __init__(self, database: str) -> None: self.database = database diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 81307d938400..373202d5cde6 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/utils/__init__.py b/src/py/flwr/server/utils/__init__.py index c370716adaac..8994374c4d08 100644 --- a/src/py/flwr/server/utils/__init__.py +++ b/src/py/flwr/server/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/utils/tensorboard.py b/src/py/flwr/server/utils/tensorboard.py index 3e8d1e62411e..5d38fc159657 100644 --- a/src/py/flwr/server/utils/tensorboard.py +++ b/src/py/flwr/server/utils/tensorboard.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/server/utils/tensorboard_test.py b/src/py/flwr/server/utils/tensorboard_test.py index 1827a42cf6e6..689755c6da16 100644 --- a/src/py/flwr/server/utils/tensorboard_test.py +++ b/src/py/flwr/server/utils/tensorboard_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/simulation/__init__.py b/src/py/flwr/simulation/__init__.py index 3d648b14edba..5db90a352e3f 100644 --- a/src/py/flwr/simulation/__init__.py +++ b/src/py/flwr/simulation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 4b4b7249ccd3..856d6fc45e22 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/simulation/ray_transport/__init__.py b/src/py/flwr/simulation/ray_transport/__init__.py index 0e82b75bb4b3..ed4971935a15 100644 --- a/src/py/flwr/simulation/ray_transport/__init__.py +++ b/src/py/flwr/simulation/ray_transport/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 5e344eb087ee..d3d103bb377a 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2021 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 3532c5a4e877..a3de1401d252 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -27,7 +27,7 @@ from flwr.client import ClientApp from flwr.common import EventType, event, log from flwr.common.logger import set_logger_propagation, update_console_handler -from flwr.common.typing import ConfigsRecordValues +from flwr.common.typing import ConfigsRecordValues, Run from flwr.server.driver import Driver, InMemoryDriver from flwr.server.run_serverapp import run from flwr.server.server_app import ServerApp @@ -169,11 +169,14 @@ def server_th_with_start_checks( # type: ignore return serverapp_th -def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> None: - """Create a run with a given `run_id`.""" +def _override_run_id(state: StateFactory, run_id_to_replace: int, run_id: int) -> None: + """Override the run_id of an existing Run.""" log(DEBUG, "Pre-registering run with id %s", run_id) - state.state().run_ids[run_id] = ("", "") # type: ignore - driver.run_id = run_id + # Remove run + run_info: Run = state.state().run_ids.pop(run_id_to_replace) # type: ignore + # Update with new run_id and insert back in state + run_info.run_id = run_id + state.state().run_ids[run_id] = run_info # type: ignore # pylint: disable=too-many-locals @@ -201,11 +204,15 @@ def _main_loop( f_stop = asyncio.Event() serverapp_th = None try: - # Initialize Driver - driver = InMemoryDriver(state_factory) + # Create run (with empty fab_id and fab_version) + run_id_ = state_factory.state().create_run("", "") if run_id: - _init_run_id(driver, state_factory, run_id) + _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id) + run_id_ = run_id + + # Initialize Driver + driver = InMemoryDriver(run_id=run_id_, state_factory=state_factory) # Get and run ServerApp thread serverapp_th = run_serverapp_th( diff --git a/src/py/flwr_tool/check_copyright.py b/src/py/flwr_tool/check_copyright.py new file mode 100755 index 000000000000..96870ba67bd0 --- /dev/null +++ b/src/py/flwr_tool/check_copyright.py @@ -0,0 +1,76 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +"""Check if copyright notices are present in all Python files. + +Example: + python -m flwr_tool.check_copyright src/py/flwr +""" + + +import os +import subprocess +import sys +from pathlib import Path +from typing import List + +from flwr_tool.init_py_check import get_init_dir_list_and_warnings + +COPYRIGHT_FORMAT = """# Copyright {} Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ==============================================================================""" + + +def _get_file_creation_year(filepath: str) -> str: + result = subprocess.run( + ["git", "log", "--diff-filter=A", "--format=%ai", "--", filepath], + stdout=subprocess.PIPE, + text=True, + check=True, + ) + date_str = result.stdout.splitlines()[-1] # Get the first commit date + creation_year = date_str.split("-")[0] # Extract the year + return creation_year + + +def _check_copyright(dir_list: List[str]) -> None: + warning_list = [] + for valid_dir in dir_list: + if "proto" in valid_dir: + continue + + dir_path = Path(valid_dir) + for py_file in dir_path.glob("*.py"): + creation_year = _get_file_creation_year(str(py_file.absolute())) + expected_copyright = COPYRIGHT_FORMAT.format(creation_year) + + if expected_copyright not in py_file.read_text(): + warning_message = "- " + str(py_file) + warning_list.append(warning_message) + + if len(warning_list) > 0: + print("Missing or incorrect copyright notice in the following files:") + for warning in warning_list: + print(warning) + sys.exit(1) + + +if __name__ == "__main__": + if len(sys.argv) == 0: + raise Exception( # pylint: disable=W0719 + "Please provide at least one directory path relative " + "to your current working directory." + ) + for i, _ in enumerate(sys.argv): + abs_path: str = os.path.abspath(os.path.join(os.getcwd(), sys.argv[i])) + __, init_dirs = get_init_dir_list_and_warnings(abs_path) + _check_copyright(init_dirs) diff --git a/src/py/flwr_tool/fix_copyright.py b/src/py/flwr_tool/fix_copyright.py new file mode 100755 index 000000000000..a5bbbdf616f7 --- /dev/null +++ b/src/py/flwr_tool/fix_copyright.py @@ -0,0 +1,59 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +"""Fix copyright notices in all Python files of a given directory. + +Example: + python -m flwr_tool.fix_copyright src/py/flwr +""" + + +import os +import sys +from pathlib import Path +from typing import List + +from flwr_tool.check_copyright import COPYRIGHT_FORMAT, _get_file_creation_year +from flwr_tool.init_py_check import get_init_dir_list_and_warnings + + +def _insert_or_edit_copyright(py_file: Path) -> None: + contents = py_file.read_text() + lines = contents.splitlines() + creation_year = _get_file_creation_year(str(py_file.absolute())) + expected_copyright = COPYRIGHT_FORMAT.format(creation_year) + + if expected_copyright not in contents: + if "Copyright" in lines[0]: + end_index = 0 + for idx, line in enumerate(lines): + if ( + line.strip() + == COPYRIGHT_FORMAT.rsplit("\n", maxsplit=1)[-1].strip() + ): + end_index = idx + 1 + break + lines = lines[end_index:] + + lines.insert(0, expected_copyright) + py_file.write_text("\n".join(lines) + "\n") + + +def _fix_copyright(dir_list: List[str]) -> None: + for valid_dir in dir_list: + if "proto" in valid_dir: + continue + + dir_path = Path(valid_dir) + for py_file in dir_path.glob("*.py"): + _insert_or_edit_copyright(py_file) + + +if __name__ == "__main__": + if len(sys.argv) == 0: + raise Exception( # pylint: disable=W0719 + "Please provide at least one directory path relative " + "to your current working directory." + ) + for i, _ in enumerate(sys.argv): + abs_path: str = os.path.abspath(os.path.join(os.getcwd(), sys.argv[i])) + __, init_dirs = get_init_dir_list_and_warnings(abs_path) + _fix_copyright(init_dirs)