diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index b8ddd355eb8e..7d78ea881034 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -81,7 +81,7 @@ jobs: - name: Set up QEMU if: matrix.platform.qemu != '' - uses: docker/setup-qemu-action@68827325e0b33c7199eb31dd4e31fbe9023e06e3 # v3.0.0 + uses: docker/setup-qemu-action@5927c834f5b4fdf503fca6f4c7eccda82949e1ee # v3.1.0 with: platforms: ${{ matrix.platform.qemu }} @@ -92,7 +92,7 @@ jobs: images: ${{ inputs.namespace-repository }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0 + uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3.4.0 - name: Login to Docker Hub uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 @@ -122,7 +122,7 @@ jobs: touch "/tmp/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3 + uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 with: name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }} path: /tmp/digests/* @@ -138,7 +138,7 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # v4.1.7 + uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests @@ -152,7 +152,7 @@ jobs: tags: ${{ inputs.tags }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0 + uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3.4.0 - name: Login to Docker Hub uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 diff --git a/.github/workflows/docker-images.yml b/.github/workflows/docker-images.yml deleted file mode 100644 index e341ae62e3f7..000000000000 --- a/.github/workflows/docker-images.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Build docker images - -on: - workflow_dispatch: - inputs: - flwr-version: - description: "Version of Flower." 
- required: true - type: string - -permissions: - contents: read - -jobs: - parameters: - name: Collect build parameters - runs-on: ubuntu-22.04 - timeout-minutes: 10 - outputs: - pip-version: ${{ steps.versions.outputs.pip-version }} - setuptools-version: ${{ steps.versions.outputs.setuptools-version }} - matrix: ${{ steps.matrix.outputs.matrix }} - steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - - - uses: ./.github/actions/bootstrap - id: bootstrap - - - id: versions - run: | - echo "pip-version=${{ steps.bootstrap.outputs.pip-version }}" >> "$GITHUB_OUTPUT" - echo "setuptools-version=${{ steps.bootstrap.outputs.setuptools-version }}" >> "$GITHUB_OUTPUT" - - - id: matrix - run: | - python dev/build-docker-image-matrix.py --flwr-version ${{ github.event.inputs.flwr-version }} > matrix.json - echo "matrix=$(cat matrix.json)" >> $GITHUB_OUTPUT - - build-base-images: - name: Build base images - uses: ./.github/workflows/_docker-build.yml - needs: parameters - strategy: - fail-fast: false - matrix: ${{ fromJson(needs.parameters.outputs.matrix).base }} - with: - namespace-repository: ${{ matrix.images.namespace_repository }} - file-dir: ${{ matrix.images.file_dir }} - build-args: | - PYTHON_VERSION=${{ matrix.images.python_version }} - PIP_VERSION=${{ needs.parameters.outputs.pip-version }} - SETUPTOOLS_VERSION=${{ needs.parameters.outputs.setuptools-version }} - DISTRO=${{ matrix.images.distro.name }} - DISTRO_VERSION=${{ matrix.images.distro.version }} - FLWR_VERSION=${{ matrix.images.flwr_version }} - tags: ${{ matrix.images.tag }} - secrets: - dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} - dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} - - build-binary-images: - name: Build binary images - uses: ./.github/workflows/_docker-build.yml - needs: [parameters, build-base-images] - strategy: - fail-fast: false - matrix: ${{ fromJson(needs.parameters.outputs.matrix).binary }} - with: - namespace-repository: ${{ matrix.images.namespace_repository }} - file-dir: ${{ matrix.images.file_dir }} - build-args: BASE_IMAGE=${{ matrix.images.base_image }} - tags: ${{ matrix.images.tags }} - secrets: - dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} - dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 4cbee6f770d6..11c0073fcf1f 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -66,44 +66,39 @@ jobs: - directory: bare-client-auth - - directory: jax + - directory: framework-jax - - directory: pytorch + - directory: framework-pytorch dataset: | from torchvision.datasets import CIFAR10 CIFAR10('./data', download=True) - - directory: tensorflow + - directory: framework-tensorflow dataset: | import tensorflow as tf tf.keras.datasets.cifar10.load_data() - - directory: tabnet - dataset: | - import tensorflow_datasets as tfds - tfds.load(name='iris', split=tfds.Split.TRAIN) - - - directory: opacus + - directory: framework-opacus dataset: | from torchvision.datasets import CIFAR10 CIFAR10('./data', download=True) - - directory: pytorch-lightning + - directory: framework-pytorch-lightning dataset: | from torchvision.datasets import MNIST MNIST('./data', download=True) - - directory: scikit-learn + - directory: framework-scikit-learn dataset: | import openml openml.datasets.get_dataset(554) - - directory: fastai + - directory: framework-fastai dataset: | from fastai.vision.all import untar_data, URLs untar_data(URLs.MNIST) - - directory: pandas + - directory: framework-pandas 
dataset: | from pathlib import Path from sklearn.datasets import load_iris @@ -145,7 +140,7 @@ jobs: run: python -c "${{ matrix.dataset }}" - name: Run edge client test if: ${{ matrix.directory != 'bare-client-auth' }} - run: ./../test.sh "${{ matrix.directory }}" + run: ./../test_legacy.sh "${{ matrix.directory }}" - name: Run virtual client test if: ${{ matrix.directory != 'bare-client-auth' }} run: python simulation.py @@ -154,16 +149,16 @@ jobs: run: python simulation_next.py - name: Run driver test if: ${{ matrix.directory != 'bare-client-auth' }} - run: ./../test_driver.sh "${{ matrix.directory }}" + run: ./../test_superlink.sh "${{ matrix.directory }}" - name: Run driver test with REST if: ${{ matrix.directory == 'bare' }} - run: ./../test_driver.sh bare rest + run: ./../test_superlink.sh bare rest - name: Run driver test with SQLite database if: ${{ matrix.directory == 'bare' }} - run: ./../test_driver.sh bare sqlite + run: ./../test_superlink.sh bare sqlite - name: Run driver test with client authentication if: ${{ matrix.directory == 'bare-client-auth' }} - run: ./../test_driver.sh bare client-auth + run: ./../test_superlink.sh bare client-auth - name: Run reconnection test with SQLite database if: ${{ matrix.directory == 'bare' }} run: ./../test_reconnection.sh sqlite diff --git a/.github/workflows/framework-release.yml b/.github/workflows/framework-release.yml index a941b47d58fc..812d5b1e398e 100644 --- a/.github/workflows/framework-release.yml +++ b/.github/workflows/framework-release.yml @@ -43,3 +43,70 @@ jobs: curl $tar_url --output dist/$tar_name python -m poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN_RELEASE_FLWR }} + + parameters: + if: ${{ github.repository == 'adap/flower' }} + name: Collect docker build parameters + runs-on: ubuntu-22.04 + timeout-minutes: 10 + needs: publish + outputs: + pip-version: ${{ steps.versions.outputs.pip-version }} + setuptools-version: ${{ steps.versions.outputs.setuptools-version }} + matrix: ${{ steps.matrix.outputs.matrix }} + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + + - uses: ./.github/actions/bootstrap + id: bootstrap + + - id: versions + run: | + echo "pip-version=${{ steps.bootstrap.outputs.pip-version }}" >> "$GITHUB_OUTPUT" + echo "setuptools-version=${{ steps.bootstrap.outputs.setuptools-version }}" >> "$GITHUB_OUTPUT" + + - id: matrix + run: | + FLWR_VERSION=$(poetry version -s) + python dev/build-docker-image-matrix.py --flwr-version "${FLWR_VERSION}" > matrix.json + echo "matrix=$(cat matrix.json)" >> $GITHUB_OUTPUT + + build-base-images: + if: ${{ github.repository == 'adap/flower' }} + name: Build base images + uses: ./.github/workflows/_docker-build.yml + needs: parameters + strategy: + fail-fast: false + matrix: ${{ fromJson(needs.parameters.outputs.matrix).base }} + with: + namespace-repository: ${{ matrix.images.namespace_repository }} + file-dir: ${{ matrix.images.file_dir }} + build-args: | + PYTHON_VERSION=${{ matrix.images.python_version }} + PIP_VERSION=${{ needs.parameters.outputs.pip-version }} + SETUPTOOLS_VERSION=${{ needs.parameters.outputs.setuptools-version }} + DISTRO=${{ matrix.images.distro.name }} + DISTRO_VERSION=${{ matrix.images.distro.version }} + FLWR_VERSION=${{ matrix.images.flwr_version }} + tags: ${{ matrix.images.tag }} + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} + + build-binary-images: + if: ${{ github.repository == 'adap/flower' }} + name: Build binary images + uses: 
./.github/workflows/_docker-build.yml + needs: [parameters, build-base-images] + strategy: + fail-fast: false + matrix: ${{ fromJson(needs.parameters.outputs.matrix).binary }} + with: + namespace-repository: ${{ matrix.images.namespace_repository }} + file-dir: ${{ matrix.images.file_dir }} + build-args: BASE_IMAGE=${{ matrix.images.base_image }} + tags: ${{ matrix.images.tags }} + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/release-nightly.yml b/.github/workflows/release-nightly.yml index 2b72190bede5..97751aafc031 100644 --- a/.github/workflows/release-nightly.yml +++ b/.github/workflows/release-nightly.yml @@ -69,7 +69,8 @@ jobs: images: [ { repository: "flwr/superlink", file_dir: "src/docker/superlink" }, { repository: "flwr/supernode", file_dir: "src/docker/supernode" }, - { repository: "flwr/serverapp", file_dir: "src/docker/serverapp" } + { repository: "flwr/serverapp", file_dir: "src/docker/serverapp" }, + { repository: "flwr/superexec", file_dir: "src/docker/superexec" } ] with: namespace-repository: ${{ matrix.images.repository }} diff --git a/datasets/README.md b/datasets/README.md index e0d8e688c779..776076308fed 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -42,12 +42,13 @@ Create **custom partitioning schemes** or choose from the **implemented [partiti * IID partitioning `IidPartitioner(num_partitions)` * Dirichlet partitioning `DirichletPartitioner(num_partitions, partition_by, alpha)` * InnerDirichlet partitioning `InnerDirichletPartitioner(partition_sizes, partition_by, alpha)` -* Natural ID partitioner `NaturalIdPartitioner(partition_by)` -* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner` -* Linear partitioner `LinearPartitioner(num_partitions)` -* Square partitioner `SquarePartitioner(num_partitions)` -* Exponential partitioner `ExponentialPartitioner(num_partitions)` -* Semantic partitioner `SemanticPartitioner(partition_by)` (only for image datasets) +* Pathological partitioning `PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)` +* Natural ID partitioning `NaturalIdPartitioner(partition_by)` +* Size based partitioning (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner` +* Linear partitioning `LinearPartitioner(num_partitions)` +* Square partitioning `SquarePartitioner(num_partitions)` +* Exponential partitioning `ExponentialPartitioner(num_partitions)` +* Semantic partitioning (only for image datasets) `SemanticPartitioner(num_partitions)` * more to come in the future releases (contributions are welcome).
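The pathological scheme listed above caps the number of distinct classes each partition receives. As a rough illustration (not part of the library code), the documented `class_assignment_mode="deterministic"` rule assigns class indices to partitions as sketched below; the toy sizes are arbitrary:

```python
# Illustration only: the documented "deterministic" assignment rule,
# (partition_id + i) % num_unique_classes, applied to a toy configuration.
num_partitions = 4
num_unique_classes = 6
num_classes_per_partition = 3

assignment = {
    pid: [(pid + i) % num_unique_classes for i in range(num_classes_per_partition)]
    for pid in range(num_partitions)
}
print(assignment)
# {0: [0, 1, 2], 1: [1, 2, 3], 2: [2, 3, 4], 3: [3, 4, 5]}
```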

Comparison of partitioning schemes. diff --git a/datasets/doc/source/conf.py b/datasets/doc/source/conf.py index 11285c375f96..c287ec318b5d 100644 --- a/datasets/doc/source/conf.py +++ b/datasets/doc/source/conf.py @@ -162,7 +162,7 @@ def find_test_modules(package_path): .. raw:: html
- + Open in Colab """ diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst index bdcea7650bbc..fcc7920711bf 100644 --- a/datasets/doc/source/index.rst +++ b/datasets/doc/source/index.rst @@ -94,6 +94,7 @@ Here are a few of the ``Partitioner`` s that are available: (for a full list see * IID partitioning ``IidPartitioner(num_partitions)`` * Dirichlet partitioning ``DirichletPartitioner(num_partitions, partition_by, alpha)`` * InnerDirichlet partitioning ``InnerDirichletPartitioner(partition_sizes, partition_by, alpha)`` +* PathologicalPartitioner ``PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)`` * Natural ID partitioner ``NaturalIdPartitioner(partition_by)`` * Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) ``SizePartitioner`` * Linear partitioner ``LinearPartitioner(num_partitions)`` diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 59435d072c60..46861aa8a0b0 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -22,6 +22,7 @@ from .linear_partitioner import LinearPartitioner from .natural_id_partitioner import NaturalIdPartitioner from .partitioner import Partitioner +from .pathological_partitioner import PathologicalPartitioner from .semantic_partitioner import SemanticPartitioner from .shard_partitioner import ShardPartitioner from .size_partitioner import SizePartitioner @@ -35,6 +36,7 @@ "LinearPartitioner", "NaturalIdPartitioner", "Partitioner", + "PathologicalPartitioner", "SemanticPartitioner", "ShardPartitioner", "SizePartitioner", diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py new file mode 100644 index 000000000000..1ee60d283044 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py @@ -0,0 +1,305 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pathological partitioner class that works with Hugging Face Datasets.""" + + +import warnings +from typing import Any, Dict, List, Literal, Optional + +import numpy as np + +import datasets +from flwr_datasets.common.typing import NDArray +from flwr_datasets.partitioner.partitioner import Partitioner + + +# pylint: disable=too-many-arguments, too-many-instance-attributes +class PathologicalPartitioner(Partitioner): + """Partition dataset such that each partition has a chosen number of classes. + + Implementation based on Federated Learning on Non-IID Data Silos: An Experimental + Study https://arxiv.org/pdf/2102.02079. + + The algorithm firstly determines which classe will be assigned to which partitions. + For each partition `num_classes_per_partition` are sampled in a way chosen in + `class_assignment_mode`. 
Given the information about the required classes for each + partition, it is determined into how many parts the samples corresponding to this + label should be divided. Such division is performed for each class. + + Parameters + ---------- + num_partitions : int + The total number of partitions that the data will be divided into. + partition_by : str + Column name of the labels (targets) based on which partitioning works. + num_classes_per_partition: int + The (exact) number of unique classes that each partition will have. + class_assignment_mode: Literal["random", "deterministic", "first-deterministic"] + The way the classes are assigned to the partitions. The default is "random". + The possible values are: + + - "random": Randomly assign classes to the partitions. For each partition choose + the `num_classes_per_partition` classes without replacement. + - "first-deterministic": Assign the first class for each partition in a + deterministic way (class id is the partition_id % num_unique_classes). + The rest of the classes are assigned randomly. In case the number of + partitions is smaller than the number of unique classes, not all classes will + be used in the first iteration, otherwise all the classes will be used (such + that each class will be present in at least one partition). + - "deterministic": Assign all the classes to the partitions in a deterministic + way. Classes are assigned based on the formula: partition_id has classes + identified by the index: (partition_id + i) % num_unique_classes + where i in {0, ..., num_classes_per_partition - 1}. So, partition 0 will have + classes 0, 1, 2, ..., `num_classes_per_partition`-1, partition 1 will have + classes 1, 2, 3, ...,`num_classes_per_partition`, .... + + The list representing the unique labels is sorted in ascending order. In case + of labels that are numbers starting from zero, the class id corresponds to the + number itself. `class_assignment_mode="first-deterministic"` was used in the + original paper; here we provide the option to use the other modes as well. + shuffle: bool + Whether to randomize the order of samples. Shuffling is applied after the + samples are assigned to partitions. + seed: int + Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+ + Examples + -------- + In order to mimic the original behavior of the paper follow the setup below + (the `class_assignment_mode="first-deterministic"`): + + >>> from flwr_datasets.partitioner import PathologicalPartitioner + >>> from flwr_datasets import FederatedDataset + >>> + >>> partitioner = PathologicalPartitioner( + >>> num_partitions=10, + >>> partition_by="label", + >>> num_classes_per_partition=2, + >>> class_assignment_mode="first-deterministic" + >>> ) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + """ + + def __init__( + self, + num_partitions: int, + partition_by: str, + num_classes_per_partition: int, + class_assignment_mode: Literal[ + "random", "deterministic", "first-deterministic" + ] = "random", + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + self._num_partitions = num_partitions + self._partition_by = partition_by + self._num_classes_per_partition = num_classes_per_partition + self._class_assignment_mode = class_assignment_mode + self._shuffle = shuffle + self._seed = seed + self._rng = np.random.default_rng(seed=self._seed) + + # Utility attributes + self._partition_id_to_indices: Dict[int, List[int]] = {} + self._partition_id_to_unique_labels: Dict[int, List[Any]] = { + pid: [] for pid in range(self._num_partitions) + } + self._unique_labels: List[Any] = [] + # Count in how many partitions the label is used + self._unique_label_to_times_used_counter: Dict[Any, int] = {} + self._partition_id_to_indices_determined = False + + def load_partition(self, partition_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + partition_id : int + The index that corresponds to the requested partition. + + Returns + ------- + dataset_partition : Dataset + Single partition of a dataset. + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. 
+ self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self.dataset.select(self._partition_id_to_indices[partition_id]) + + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self._num_partitions + + def _determine_partition_id_to_indices_if_needed(self) -> None: + """Create an assignment of indices to the partition indices.""" + if self._partition_id_to_indices_determined: + return + self._determine_partition_id_to_unique_labels() + assert self._unique_labels is not None + self._count_partitions_having_each_unique_label() + + labels = np.asarray(self.dataset[self._partition_by]) + self._check_correctness_of_unique_label_to_times_used_counter(labels) + for partition_id in range(self._num_partitions): + self._partition_id_to_indices[partition_id] = [] + + unused_labels = [] + for unique_label in self._unique_labels: + if self._unique_label_to_times_used_counter[unique_label] == 0: + unused_labels.append(unique_label) + continue + # Get the indices in the original dataset where the y == unique_label + unique_label_to_indices = np.where(labels == unique_label)[0] + + split_unique_labels_to_indices = np.array_split( + unique_label_to_indices, + self._unique_label_to_times_used_counter[unique_label], + ) + + split_index = 0 + for partition_id in range(self._num_partitions): + if unique_label in self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_indices[partition_id].extend( + split_unique_labels_to_indices[split_index] + ) + split_index += 1 + + if len(unused_labels) >= 1: + warnings.warn( + f"Classes: {unused_labels} will NOT be used due to the chosen " + f"configuration. If this is undesired behavior, consider setting " + f"`class_assignment_mode='first-deterministic'`, which, when the number " + f"of classes is smaller than the number of partitions, will " + f"utilize all the classes for the created partitions.", + stacklevel=1, + ) + if self._shuffle: + for indices in self._partition_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + + self._partition_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._partition_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _determine_partition_id_to_unique_labels(self) -> None: + """Determine the assignment of unique labels to the partitions.""" + self._unique_labels = sorted(self.dataset.unique(self._partition_by)) + num_unique_classes = len(self._unique_labels) + + if self._num_classes_per_partition > num_unique_classes: + raise ValueError( + f"The specified `num_classes_per_partition`" + f"={self._num_classes_per_partition} is greater than the number " + f"of unique classes in the given dataset={num_unique_classes}. " + f"Reduce the `num_classes_per_partition` or make use different dataset " + f"to apply this partitioning."
+ ) + if self._class_assignment_mode == "first-deterministic": + # if self._first_class_deterministic_assignment: + for partition_id in range(self._num_partitions): + label = partition_id % num_unique_classes + self._partition_id_to_unique_labels[partition_id].append(label) + + while ( + len(self._partition_id_to_unique_labels[partition_id]) + < self._num_classes_per_partition + ): + label = self._rng.choice(self._unique_labels, size=1)[0] + if label not in self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_unique_labels[partition_id].append(label) + elif self._class_assignment_mode == "deterministic": + for partition_id in range(self._num_partitions): + labels = [] + for i in range(self._num_classes_per_partition): + label = self._unique_labels[ + (partition_id + i) % len(self._unique_labels) + ] + labels.append(label) + self._partition_id_to_unique_labels[partition_id] = labels + elif self._class_assignment_mode == "random": + for partition_id in range(self._num_partitions): + labels = self._rng.choice( + self._unique_labels, + size=self._num_classes_per_partition, + replace=False, + ).tolist() + self._partition_id_to_unique_labels[partition_id] = labels + else: + raise ValueError( + f"The supported class_assignment_mode are: 'random', 'deterministic', " + f"'first-deterministic'. You provided: {self._class_assignment_mode}." + ) + + def _count_partitions_having_each_unique_label(self) -> None: + """Count the number of partitions that have each unique label. + + This computation is based on the assigment of the label to the partition_id in + the `_determine_partition_id_to_unique_labels` method. + Given: + * partition 0 has only labels: 0,1 (not necessarily just two samples it can have + many samples but either from 0 or 1) + * partition 1 has only labels: 1, 2 (same count note as above) + * and there are only two partitions then the following will be computed: + { + 0: 1, + 1: 2, + 2: 1 + } + """ + for unique_label in self._unique_labels: + self._unique_label_to_times_used_counter[unique_label] = 0 + for unique_labels in self._partition_id_to_unique_labels.values(): + for unique_label in unique_labels: + self._unique_label_to_times_used_counter[unique_label] += 1 + + def _check_correctness_of_unique_label_to_times_used_counter( + self, labels: NDArray + ) -> None: + """Check if partitioning is possible given the presence requirements. + + The number of times the label can be used must be smaller or equal to the number + of times that the label is present in the dataset. + """ + for unique_label in self._unique_labels: + num_unique = np.sum(labels == unique_label) + if self._unique_label_to_times_used_counter[unique_label] > num_unique: + raise ValueError( + f"Label: {unique_label} is needed to be assigned to more " + f"partitions " + f"({self._unique_label_to_times_used_counter[unique_label]})" + f" than there are samples (corresponding to this label) in the " + f"dataset ({num_unique}). Please decrease the `num_partitions`, " + f"`num_classes_per_partition` to avoid this situation, " + f"or try `class_assigment_mode='deterministic'` to create a more " + f"even distribution of classes along the partitions. " + f"Alternatively use a different dataset if you can not adjust" + f" the any of these parameters." 
+ ) diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py new file mode 100644 index 000000000000..151b7e14659c --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py @@ -0,0 +1,262 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for PathologicalPartitioner.""" + + +import unittest +from typing import Dict + +import numpy as np +from parameterized import parameterized + +import datasets +from datasets import Dataset +from flwr_datasets.partitioner.pathological_partitioner import PathologicalPartitioner + + +def _dummy_dataset_setup( + num_samples: int, partition_by: str, num_unique_classes: int +) -> Dataset: + """Create a dummy dataset for testing.""" + data = { + partition_by: np.tile( + np.arange(num_unique_classes), num_samples // num_unique_classes + 1 + )[:num_samples], + "features": np.random.randn(num_samples), + } + return Dataset.from_dict(data) + + +def _dummy_heterogeneous_dataset_setup( + num_samples: int, partition_by: str, num_unique_classes: int +) -> Dataset: + """Create a dummy dataset for testing.""" + data = { + partition_by: np.tile( + np.arange(num_unique_classes), num_samples // num_unique_classes + 1 + )[:num_samples], + "features": np.random.randn(num_samples), + } + return Dataset.from_dict(data) + + +class TestClassConstrainedPartitioner(unittest.TestCase): + """Unit tests for PathologicalPartitioner.""" + + @parameterized.expand( # type: ignore + [ + # num_partition, num_classes_per_partition, num_samples, total_classes + (3, 1, 60, 3), # Single class per partition scenario + (5, 2, 100, 5), + (5, 2, 100, 10), + (4, 3, 120, 6), + ] + ) + def test_correct_num_classes_when_partitioned( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test correct number of unique classes.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitions: Dict[int, Dataset] = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + unique_classes_per_partition = { + pid: np.unique(partition["labels"]) for pid, partition in partitions.items() + } + + for unique_classes in unique_classes_per_partition.values(): + self.assertEqual(num_classes_per_partition, len(unique_classes)) + + def test_first_class_deterministic_assignment(self) -> None: + """Test deterministic assignment of first classes to partitions. + + Test if all the classes are used (which has to be the case, given num_partitions + >= than the number of unique classes). 
+ """ + dataset = _dummy_dataset_setup(100, "labels", 10) + partitioner = PathologicalPartitioner( + num_partitions=10, + partition_by="labels", + num_classes_per_partition=2, + class_assignment_mode="first-deterministic", + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + expected_classes = set(range(10)) + actual_classes = set() + for pid in range(10): + partition = partitioner.load_partition(pid) + actual_classes.update(np.unique(partition["labels"])) + self.assertEqual(expected_classes, actual_classes) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (4, 2, 80, 8), + (10, 2, 100, 10), + ] + ) + def test_deterministic_class_assignment( + self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes + ): + """Test deterministic assignment of classes to partitions.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="deterministic", + ) + partitioner.dataset = dataset + partitions = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + + # Verify each partition has the expected classes, order does not matter + for pid, partition in partitions.items(): + expected_labels = sorted( + [ + (pid + i) % num_unique_classes + for i in range(num_classes_per_partition) + ] + ) + actual_labels = sorted(np.unique(partition["labels"])) + self.assertTrue( + np.array_equal(expected_labels, actual_labels), + f"Partition {pid} does not have the expected labels: " + f"{expected_labels} but instead {actual_labels}.", + ) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 3, 20, 3), + ] + ) + def test_too_many_partitions_for_a_class( + self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes + ) -> None: + """Test too many partitions for the number of samples in a class.""" + dataset_1 = _dummy_dataset_setup( + num_samples // 2, "labels", num_unique_classes - 1 + ) + # Create a skewed part of the dataset for the last label + data = { + "labels": np.array([num_unique_classes - 1] * (num_samples // 2)), + "features": np.random.randn(num_samples // 2), + } + dataset_2 = Dataset.from_dict(data) + dataset = datasets.concatenate_datasets([dataset_1, dataset_2]) + + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="random", + ) + partitioner.dataset = dataset + + with self.assertRaises(ValueError) as context: + _ = partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "Label: 0 is needed to be assigned to more partitions (10) than there are " + "samples (corresponding to this label) in the dataset (5). " + "Please decrease the `num_partitions`, `num_classes_per_partition` to " + "avoid this situation, or try `class_assigment_mode='deterministic'` to " + "create a more even distribution of classes along the partitions. 
" + "Alternatively use a different dataset if you can not adjust the any of " + "these parameters.", + ) + + @parameterized.expand( # type: ignore + [ + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 11, 100, 10), # 11 > 10 + (5, 11, 100, 10), # 11 > 10 + (10, 20, 100, 5), # 20 > 5 + ] + ) + def test_more_classes_per_partition_than_num_unique_classes_in_dataset_raises( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test more num_classes_per_partition > num_unique_classes in the dataset.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The specified " + f"`num_classes_per_partition`={num_classes_per_partition} is " + f"greater than the number of unique classes in the given " + f"dataset={len(dataset.unique('labels'))}. Reduce the " + f"`num_classes_per_partition` or make use different dataset " + f"to apply this partitioning.", + ) + + @parameterized.expand( # type: ignore + [ + # num_classes_per_partition should be irrelevant since the exception should + # be raised at the very beginning + # num_partitions, num_classes_per_partition, num_samples + (10, 2, 5), + (10, 10, 5), + (100, 10, 99), + ] + ) + def test_more_partitions_than_samples_raises( + self, num_partitions: int, num_classes_per_partition: int, num_samples: int + ) -> None: + """Test if generation of more partitions that there are samples raises.""" + # The number of unique classes in the dataset should be irrelevant since the + # exception should be raised at the very beginning + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes=5) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The number of partitions needs to be smaller than the number of " + "samples in the dataset.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/dev/build-docker-image-matrix.py b/dev/build-docker-image-matrix.py index 51d7fd0083d1..b7c4d2daaefd 100644 --- a/dev/build-docker-image-matrix.py +++ b/dev/build-docker-image-matrix.py @@ -168,6 +168,13 @@ def tag_latest_ubuntu_with_flwr_version(image: BaseImage) -> List[str]: tag_latest_ubuntu_with_flwr_version, lambda image: image.distro.name == DistroName.UBUNTU, ) + # ubuntu images for each supported python version + + generate_binary_images( + "superexec", + base_images, + tag_latest_ubuntu_with_flwr_version, + lambda image: image.distro.name == DistroName.UBUNTU, + ) ) print( diff --git a/dev/build-docs.sh b/dev/build-docs.sh index f8d4f91508de..f4bf958b0ebf 100755 --- a/dev/build-docs.sh +++ b/dev/build-docs.sh @@ -8,9 +8,7 @@ cd $ROOT ./dev/build-baseline-docs.sh cd $ROOT -./dev/update-examples.sh -cd examples/doc -make docs +python dev/build-example-docs.py cd $ROOT ./datasets/dev/build-flwr-datasets-docs.sh diff --git a/dev/build-example-docs.py b/dev/build-example-docs.py new file mode 100644 index 000000000000..367994708bf9 --- 
/dev/null +++ b/dev/build-example-docs.py @@ -0,0 +1,283 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Build the Flower Example docs.""" + +import os +import shutil +import re +import subprocess +from pathlib import Path + +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +INDEX = os.path.join(ROOT, "examples", "doc", "source", "index.rst") + +initial_text = """ +Flower Examples Documentation +----------------------------- + +Welcome to Flower Examples' documentation. `Flower `_ is +a friendly federated learning framework. + +Join the Flower Community +------------------------- + +The Flower Community is growing quickly - we're a friendly group of researchers, +engineers, students, professionals, academics, and other enthusiasts. + +.. button-link:: https://flower.ai/join-slack + :color: primary + :shadow: + + Join us on Slack + +Quickstart Examples +------------------- + +Flower Quickstart Examples are a collection of demo projects that show how you +can use Flower in combination with other existing frameworks or technologies. + +""" + +table_headers = ( + "\n.. list-table::\n :widths: 50 15 15 15\n " + ":header-rows: 1\n\n * - Title\n - Framework\n - Dataset\n - Tags\n\n" +) + +categories = { + "quickstart": {"table": table_headers, "list": ""}, + "advanced": {"table": table_headers, "list": ""}, + "other": {"table": table_headers, "list": ""}, +} + +urls = { + # Frameworks + "Android": "https://www.android.com/", + "C++": "https://isocpp.org/", + "Docker": "https://www.docker.com/", + "JAX": "https://jax.readthedocs.io/en/latest/", + "Java": "https://www.java.com/", + "Keras": "https://keras.io/", + "Kotlin": "https://kotlinlang.org/", + "mlcube": "https://docs.mlcommons.org/mlcube/", + "MLX": "https://ml-explore.github.io/mlx/build/html/index.html", + "MONAI": "https://monai.io/", + "PEFT": "https://huggingface.co/docs/peft/index", + "Swift": "https://www.swift.org/", + "TensorFlowLite": "https://www.tensorflow.org/lite", + "fastai": "https://fast.ai/", + "lifelines": "https://lifelines.readthedocs.io/en/latest/index.html", + "lightning": "https://lightning.ai/docs/pytorch/stable/", + "numpy": "https://numpy.org/", + "opacus": "https://opacus.ai/", + "pandas": "https://pandas.pydata.org/", + "scikit-learn": "https://scikit-learn.org/", + "tabnet": "https://github.com/titu1994/tf-TabNet", + "tensorboard": "https://www.tensorflow.org/tensorboard", + "tensorflow": "https://www.tensorflow.org/", + "torch": "https://pytorch.org/", + "torchvision": "https://pytorch.org/vision/stable/index.html", + "transformers": "https://huggingface.co/docs/transformers/index", + "wandb": "https://wandb.ai/home", + "whisper": "https://huggingface.co/openai/whisper-tiny", + "xgboost": "https://xgboost.readthedocs.io/en/stable/", + # Datasets + "Adult Census Income": "https://www.kaggle.com/datasets/uciml/adult-census-income/data", + 
"Alpaca-GPT4": "https://huggingface.co/datasets/vicgalle/alpaca-gpt4", + "CIFAR-10": "https://huggingface.co/datasets/uoft-cs/cifar10", + "HIGGS": "https://archive.ics.uci.edu/dataset/280/higgs", + "IMDB": "https://huggingface.co/datasets/stanfordnlp/imdb", + "Iris": "https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html", + "MNIST": "https://huggingface.co/datasets/ylecun/mnist", + "MedNIST": "https://medmnist.com/", + "Oxford Flower-102": "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/", + "SpeechCommands": "https://huggingface.co/datasets/google/speech_commands", + "Titanic": "https://www.kaggle.com/competitions/titanic", + "Waltons": "https://lifelines.readthedocs.io/en/latest/lifelines.datasets.html#lifelines.datasets.load_waltons", +} + + +def _convert_to_link(search_result): + if "," in search_result: + result = "" + for part in search_result.split(","): + result += f"{_convert_to_link(part)}, " + return result[:-2] + else: + search_result = search_result.strip() + name, url = search_result, urls.get(search_result, None) + if url: + return f"`{name.strip()} <{url.strip()}>`_" + else: + return search_result + + +def _read_metadata(example): + with open(os.path.join(example, "README.md")) as f: + content = f.read() + + metadata_match = re.search(r"^---(.*?)^---", content, re.DOTALL | re.MULTILINE) + if not metadata_match: + raise ValueError("Metadata block not found") + metadata = metadata_match.group(1) + + title_match = re.search(r"^# (.+)$", content, re.MULTILINE) + if not title_match: + raise ValueError("Title not found in metadata") + title = title_match.group(1).strip() + + tags_match = re.search(r"^tags:\s*\[(.+?)\]$", metadata, re.MULTILINE) + if not tags_match: + raise ValueError("Tags not found in metadata") + tags = tags_match.group(1).strip() + + dataset_match = re.search( + r"^dataset:\s*\[(.*?)\]$", metadata, re.DOTALL | re.MULTILINE + ) + if not dataset_match: + raise ValueError("Dataset not found in metadata") + dataset = dataset_match.group(1).strip() + + framework_match = re.search( + r"^framework:\s*\[(.*?|)\]$", metadata, re.DOTALL | re.MULTILINE + ) + if not framework_match: + raise ValueError("Framework not found in metadata") + framework = framework_match.group(1).strip() + + dataset = _convert_to_link(re.sub(r"\s+", " ", dataset).strip()) + framework = _convert_to_link(re.sub(r"\s+", " ", framework).strip()) + return title, tags, dataset, framework + + +def _add_table_entry(example, tag, table_var): + title, tags, dataset, framework = _read_metadata(example) + example_name = Path(example).stem + table_entry = ( + f" * - `{title} <{example_name}.html>`_ \n " + f"- {framework} \n - {dataset} \n - {tags}\n\n" + ) + if tag in tags: + categories[table_var]["table"] += table_entry + categories[table_var]["list"] += f" {example_name}\n" + return True + return False + + +def _copy_markdown_files(example): + for file in os.listdir(example): + if file.endswith(".md"): + src = os.path.join(example, file) + dest = os.path.join( + ROOT, "examples", "doc", "source", os.path.basename(example) + ".md" + ) + shutil.copyfile(src, dest) + + +def _add_gh_button(example): + gh_text = f'[View on GitHub](https://github.com/adap/flower/blob/main/examples/{example})' + readme_file = os.path.join(ROOT, "examples", "doc", "source", example + ".md") + with open(readme_file, "r+") as f: + content = f.read() + if gh_text not in content: + content = re.sub( + r"(^# .+$)", rf"\1\n\n{gh_text}", content, count=1, flags=re.MULTILINE + ) + f.seek(0) + f.write(content) 
+ f.truncate() + + +def _copy_images(example): + static_dir = os.path.join(example, "_static") + dest_dir = os.path.join(ROOT, "examples", "doc", "source", "_static") + if os.path.isdir(static_dir): + for file in os.listdir(static_dir): + if file.endswith((".jpg", ".png", ".jpeg")): + shutil.copyfile( + os.path.join(static_dir, file), os.path.join(dest_dir, file) + ) + + +def _add_all_entries(): + examples_dir = os.path.join(ROOT, "examples") + for example in sorted(os.listdir(examples_dir)): + example_path = os.path.join(examples_dir, example) + if os.path.isdir(example_path) and example != "doc": + _copy_markdown_files(example_path) + _add_gh_button(example) + _copy_images(example) + + +def _main(): + if os.path.exists(INDEX): + os.remove(INDEX) + + with open(INDEX, "w") as index_file: + index_file.write(initial_text) + + examples_dir = os.path.join(ROOT, "examples") + for example in sorted(os.listdir(examples_dir)): + example_path = os.path.join(examples_dir, example) + if os.path.isdir(example_path) and example != "doc": + _copy_markdown_files(example_path) + _add_gh_button(example) + _copy_images(example_path) + if not _add_table_entry(example_path, "quickstart", "quickstart"): + if not _add_table_entry(example_path, "comprehensive", "comprehensive"): + if not _add_table_entry(example_path, "advanced", "advanced"): + _add_table_entry(example_path, "", "other") + + with open(INDEX, "a") as index_file: + index_file.write(categories["quickstart"]["table"]) + + index_file.write("\nAdvanced Examples\n-----------------\n") + index_file.write( + "Advanced Examples are mostly for users that are both familiar with " + "Federated Learning but also somewhat familiar with Flower's main " + "features.\n" + ) + index_file.write(categories["advanced"]["table"]) + + index_file.write("\nOther Examples\n--------------\n") + index_file.write( + "Flower Examples are a collection of example projects written with " + "Flower that explore different domains and features. You can check " + "which examples already exist and/or contribute your own example.\n" + ) + index_file.write(categories["other"]["table"]) + + _add_all_entries() + + index_file.write( + "\n.. toctree::\n :maxdepth: 1\n :caption: Quickstart\n :hidden:\n\n" + ) + index_file.write(categories["quickstart"]["list"]) + + index_file.write( + "\n.. toctree::\n :maxdepth: 1\n :caption: Advanced\n :hidden:\n\n" + ) + index_file.write(categories["advanced"]["list"]) + + index_file.write( + "\n.. 
toctree::\n :maxdepth: 1\n :caption: Others\n :hidden:\n\n" + ) + index_file.write(categories["other"]["list"]) + + index_file.write("\n") + + +if __name__ == "__main__": + _main() + subprocess.call(f"cd {ROOT}/examples/doc && make html", shell=True) diff --git a/dev/changelog_config.toml b/dev/changelog_config.toml index c5ff1bcdd1c1..637ea9b4b2c6 100644 --- a/dev/changelog_config.toml +++ b/dev/changelog_config.toml @@ -3,7 +3,7 @@ type = ["ci", "docs", "feat", "fix", "refactor", "break"] -project = ["framework", "baselines", "datasets", "examples"] +project = ["framework", "baselines", "datasets", "examples", "benchmarks"] scope = "skip" diff --git a/dev/format.sh b/dev/format.sh index 71edf9c6065a..e1e2abc307f1 100755 --- a/dev/format.sh +++ b/dev/format.sh @@ -18,6 +18,11 @@ find src/proto/flwr/proto -name *.proto | grep "\.proto" | xargs clang-format -i python -m black -q examples python -m docformatter -i -r examples +# Benchmarks +python -m isort benchmarks +python -m black -q benchmarks +python -m docformatter -i -r benchmarks + # E2E python -m isort e2e python -m black -q e2e diff --git a/dev/test.sh b/dev/test.sh index 8cbe88c9298b..58ac0b3d24cd 100755 --- a/dev/test.sh +++ b/dev/test.sh @@ -11,11 +11,11 @@ clang-format --Werror --dry-run src/proto/flwr/proto/* echo "- clang-format: done" echo "- isort: start" -python -m isort --check-only --skip src/py/flwr/proto src/py/flwr e2e +python -m isort --check-only --skip src/py/flwr/proto src/py/flwr benchmarks e2e echo "- isort: done" echo "- black: start" -python -m black --exclude "src\/py\/flwr\/proto" --check src/py/flwr examples e2e +python -m black --exclude "src\/py\/flwr\/proto" --check src/py/flwr benchmarks examples e2e echo "- black: done" echo "- init_py_check: start" diff --git a/dev/update-examples.sh b/dev/update-examples.sh deleted file mode 100755 index 1076b4621984..000000000000 --- a/dev/update-examples.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -set -e -cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../ - -ROOT=`pwd` -INDEX=$ROOT/examples/doc/source/index.md -INSERT_LINE=6 - -copy_markdown_files () { - for file in $1/*.md; do - # Copy the README into the source of the Example docs as the name of the example - if [[ $(basename "$file") = "README.md" ]]; then - cp $file $ROOT/examples/doc/source/$1.md 2>&1 >/dev/null - else - # If the example contains other markdown files, copy them to the source of the Example docs - cp $file $ROOT/examples/doc/source/$(basename "$file") 2>&1 >/dev/null - fi - done -} - -add_gh_button () { - gh_text="[\"View](https://github.com/adap/flower/blob/main/examples/$1)" - readme_file="$ROOT/examples/doc/source/$1.md" - - if ! 
grep -Fq "$gh_text" "$readme_file"; then - awk -v text="$gh_text" ' - /^# / && !found { - print $0 "\n" text; - found=1; - next; - } - { print } - ' "$readme_file" > tmpfile && mv tmpfile "$readme_file" - fi -} - -copy_images () { - if [ -d "$1/_static" ]; then - cp $1/_static/**.{jpg,png,jpeg} $ROOT/examples/doc/source/_static/ 2>/dev/null || true - fi -} - -add_to_index () { - (echo $INSERT_LINE; echo a; echo $1; echo .; echo wq) | ed $INDEX 2>&1 >/dev/null -} - -add_single_entry () { - # Copy markdown files to correct folder - copy_markdown_files $1 - - # Add button linked to GitHub - add_gh_button $1 - - # Copy all images of the _static folder into the examples - # docs static folder - copy_images $1 - - # Insert the name of the example into the index file - add_to_index $1 -} - -add_all_entries () { - cd $ROOT/examples - # Iterate through each folder in examples/ - for d in $(printf '%s\n' */ | sort -V); do - # Add entry based on the name of the folder - example=${d%/} - - if [[ $example != doc ]]; then - add_single_entry $example - fi - done -} - -# Clean up before starting -rm -f $ROOT/examples/doc/source/*.md -rm -f $INDEX - -# Create empty index file -touch $INDEX - -echo "# Flower Examples Documentation" >> $INDEX -echo "" >> $INDEX -echo "\`\`\`{toctree}" >> $INDEX -echo "---" >> $INDEX -echo "maxdepth: 1" >> $INDEX -echo "---" >> $INDEX - -add_all_entries - -echo "\`\`\`" >> $INDEX diff --git a/doc/source/_static/docker-ci-release.png b/doc/source/_static/docker-ci-release.png deleted file mode 100644 index 6ec97ce9fb06..000000000000 Binary files a/doc/source/_static/docker-ci-release.png and /dev/null differ diff --git a/doc/source/contributor-how-to-release-flower.rst b/doc/source/contributor-how-to-release-flower.rst index fc4c2d436b05..4853d87bc4c1 100644 --- a/doc/source/contributor-how-to-release-flower.rst +++ b/doc/source/contributor-how-to-release-flower.rst @@ -12,24 +12,6 @@ The version number of a release is stated in ``pyproject.toml``. To release a ne 2. Once the changelog has been updated with all the changes, run ``./dev/prepare-release-changelog.sh v``, where ```` is the version stated in ``pyproject.toml`` (notice the ``v`` added before it). This will replace the ``Unreleased`` header of the changelog by the version and current date, and it will add a thanking message for the contributors. Open a pull request with those changes. 3. Once the pull request is merged, tag the release commit with the version number as soon as the PR is merged: ``git tag v`` (notice the ``v`` added before the version number), then ``git push --tags``. This will create a draft release on GitHub containing the correct artifacts and the relevant part of the changelog. 4. Check the draft release on GitHub, and if everything is good, publish it. -5. Trigger the CI for building the Docker images. - -To trigger the workflow, a collaborator must create a ``workflow_dispatch`` event in the -GitHub CI. This can be done either through the UI or via the GitHub CLI. The event requires only one -input, the Flower version, to be released. - -**Via the UI** - -1. Go to the ``Build docker images`` workflow `page `_. -2. Click on the ``Run workflow`` button and type the new version of Flower in the ``Version of Flower`` input field. -3. Click on the **green** ``Run workflow`` button. - -.. image:: _static/docker-ci-release.png - -**Via the GitHub CI** - -1. Make sure you are logged in via ``gh auth login`` and that the current working directory is the root of the Flower repository. -2. 
Trigger the workflow via ``gh workflow run docker-images.yml -f flwr-version=``. After the release ----------------- diff --git a/doc/source/explanation-differential-privacy.rst b/doc/source/explanation-differential-privacy.rst index 69fd333f9b13..e488f5ccbd57 100644 --- a/doc/source/explanation-differential-privacy.rst +++ b/doc/source/explanation-differential-privacy.rst @@ -32,7 +32,7 @@ and for all possible outputs S ⊆ Range(A): .. math:: \small - P[M(D_{1} \in A)] \leq e^{\delta} P[M(D_{2} \in A)] + \delta + P[M(D_{1} \in A)] \leq e^{\epsilon} P[M(D_{2} \in A)] + \delta The :math:`\epsilon` parameter, also known as the privacy budget, is a metric of privacy loss. diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst index 7ac055138814..34ad5f6e99c0 100644 --- a/doc/source/tutorial-quickstart-xgboost.rst +++ b/doc/source/tutorial-quickstart-xgboost.rst @@ -96,26 +96,26 @@ Prior to local training, we require loading the HIGGS dataset from Flower Datase fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) # Load the partition for this `node_id` - partition = fds.load_partition(node_id=args.node_id, split="train") + partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") -In this example, we split the dataset into two partitions with uniform distribution (:code:`IidPartitioner(num_partitions=2)`). -Then, we load the partition for the given client based on :code:`node_id`: +In this example, we split the dataset into 30 partitions with uniform distribution (:code:`IidPartitioner(num_partitions=30)`). +Then, we load the partition for the given client based on :code:`partition_id`: .. code-block:: python - # We first define arguments parser for user to specify the client/node ID. + # We first define arguments parser for user to specify the client/partition ID. parser = argparse.ArgumentParser() parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) args = parser.parse_args() - # Load the partition for this `node_id`. - partition = fds.load_partition(idx=args.node_id, split="train") + # Load the partition for this `partition_id`. + partition = fds.load_partition(idx=args.partition_id, split="train") partition.set_format("numpy") After that, we do train/test splitting on the given partition (client's local data), and transform data format for :code:`xgboost` package. @@ -186,12 +186,23 @@ We follow the general rule to define :code:`XgbClient` class inherited from :cod .. code-block:: python class XgbClient(fl.client.Client): - def __init__(self): - self.bst = None - self.config = None + def __init__( + self, + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + ): + self.train_dmatrix = train_dmatrix + self.valid_dmatrix = valid_dmatrix + self.num_train = num_train + self.num_val = num_val + self.num_local_round = num_local_round + self.params = params -The :code:`self.bst` is used to keep the Booster objects that remain consistent across rounds, -allowing them to store predictions from trees integrated in earlier rounds and maintain other essential data structures for training. +All required parameters defined above are passed to :code:`XgbClient`'s constructor. Then, we override :code:`get_parameters`, :code:`fit` and :code:`evaluate` methods insides :code:`XgbClient` class as follows. 
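For context, the snippet below is a minimal sketch of how these constructor inputs could be prepared from the HIGGS partition loaded earlier. The 80/20 split, the column names :code:`inputs` and :code:`label`, the XGBoost parameters, and :code:`num_local_round=1` are illustrative assumptions, not part of this change.

.. code-block:: python

    import xgboost as xgb

    # Assumed: `partition` is the NumPy-formatted HIGGS partition loaded above
    train_test = partition.train_test_split(test_size=0.2, seed=42)
    train_data, valid_data = train_test["train"], train_test["test"]
    num_train, num_val = len(train_data), len(valid_data)

    # Assumed column names: "inputs" holds the features, "label" the target
    train_dmatrix = xgb.DMatrix(train_data["inputs"], label=train_data["label"])
    valid_dmatrix = xgb.DMatrix(valid_data["inputs"], label=valid_data["label"])

    # Number of local boosting rounds per FL round (illustrative value)
    num_local_round = 1

    # Typical XGBoost parameters for binary classification evaluated with AUC
    params = {
        "objective": "binary:logistic",
        "eta": 0.1,
        "max_depth": 8,
        "eval_metric": "auc",
        "tree_method": "hist",
    }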
@@ -214,27 +225,27 @@ As a result, let's return an empty tensor in :code:`get_parameters` when it is c .. code-block:: python def fit(self, ins: FitIns) -> FitRes: - if not self.bst: + global_round = int(ins.config["global_round"]) + if global_round == 1: # First round local training - log(INFO, "Start training at round 1") bst = xgb.train( - params, - train_dmatrix, - num_boost_round=num_local_round, - evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")], + self.params, + self.train_dmatrix, + num_boost_round=self.num_local_round, + evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")], ) - self.config = bst.save_config() - self.bst = bst else: + bst = xgb.Booster(params=self.params) for item in ins.parameters.tensors: global_model = bytearray(item) # Load global model into booster - self.bst.load_model(global_model) - self.bst.load_config(self.config) + bst.load_model(global_model) - bst = self._local_boost() + # Local training + bst = self._local_boost(bst) + # Save model local_model = bst.save_raw("json") local_model_bytes = bytes(local_model) @@ -244,60 +255,81 @@ As a result, let's return an empty tensor in :code:`get_parameters` when it is c message="OK", ), parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), - num_examples=num_train, + num_examples=self.num_train, metrics={}, ) In :code:`fit`, at the first round, we call :code:`xgb.train()` to build up the first set of trees. -the returned Booster object and config are stored in :code:`self.bst` and :code:`self.config`, respectively. -From the second round, we load the global model sent from server to :code:`self.bst`, +From the second round, we load the global model sent from the server into a newly built Booster object, and then update model weights on local training data with function :code:`local_boost` as follows: .. code-block:: python - def _local_boost(self): + def _local_boost(self, bst_input): # Update trees based on local training data. - for i in range(num_local_round): - self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) + for i in range(self.num_local_round): + bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds()) - # Extract the last N=num_local_round trees for sever aggregation - bst = self.bst[ - self.bst.num_boosted_rounds() - - num_local_round : self.bst.num_boosted_rounds() + # Bagging: extract the last N=num_local_round trees for server aggregation + bst = bst_input[ + bst_input.num_boosted_rounds() - - self.num_local_round : bst_input.num_boosted_rounds() + ] -Given :code:`num_local_round`, we update trees by calling :code:`self.bst.update` method. + return bst + +Given :code:`num_local_round`, we update trees by calling the :code:`bst_input.update` method. After training, the last :code:`N=num_local_round` trees will be extracted to send to the server. .. 
code-block:: python def evaluate(self, ins: EvaluateIns) -> EvaluateRes: - eval_results = self.bst.eval_set( - evals=[(valid_dmatrix, "valid")], - iteration=self.bst.num_boosted_rounds() - 1, + # Load global model + bst = xgb.Booster(params=self.params) + for para in ins.parameters.tensors: + para_b = bytearray(para) + bst.load_model(para_b) + + # Run evaluation + eval_results = bst.eval_set( + evals=[(self.valid_dmatrix, "valid")], + iteration=bst.num_boosted_rounds() - 1, ) auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + global_round = ins.config["global_round"] + log(INFO, f"AUC = {auc} at round {global_round}") + return EvaluateRes( status=Status( code=Code.OK, message="OK", ), loss=0.0, - num_examples=num_val, + num_examples=self.num_val, metrics={"AUC": auc}, ) -In :code:`evaluate`, we call :code:`self.bst.eval_set` function to conduct evaluation on valid set. +In :code:`evaluate`, after loading the global model, we call :code:`bst.eval_set` function to conduct evaluation on valid set. The AUC value will be returned. Now, we can create an instance of our class :code:`XgbClient` and add one line to actually run this client: .. code-block:: python - fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) + fl.client.start_client( + server_address="127.0.0.1:8080", + client=XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + ).to_client(), + ) -That's it for the client. We only have to implement :code:`Client`and call :code:`fl.client.start_client()`. +That's it for the client. We only have to implement :code:`Client` and call :code:`fl.client.start_client()`. The string :code:`"[::]:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"[::]:8080"`. If we run a truly federated workload with the server and @@ -325,6 +357,8 @@ We first define a strategy for XGBoost bagging aggregation. min_evaluate_clients=2, fraction_evaluate=1.0, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=config_func, + on_fit_config_fn=config_func, ) def evaluate_metrics_aggregation(eval_metrics): @@ -336,8 +370,16 @@ We first define a strategy for XGBoost bagging aggregation. metrics_aggregated = {"AUC": auc_aggregated} return metrics_aggregated + def config_func(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + We use two clients for this example. An :code:`evaluate_metrics_aggregation` function is defined to collect and wighted average the AUC values from clients. +The :code:`config_func` function is to return the current FL round number to client's :code:`fit()` and :code:`evaluate()` methods. Then, we start the server: @@ -346,7 +388,7 @@ Then, we start the server: # Start Flower server fl.server.start_server( server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=num_rounds), + config=fl.server.ServerConfig(num_rounds=5), strategy=strategy, ) @@ -535,52 +577,66 @@ Open a new terminal and start the first client: .. code-block:: shell - $ python3 client.py --node-id=0 + $ python3 client.py --partition-id=0 Open another terminal and start the second client: .. code-block:: shell - $ python3 client.py --node-id=1 + $ python3 client.py --partition-id=1 Each client will have its own dataset. 
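For completeness, the :code:`evaluate_metrics_aggregation` function used by the strategy above computes an example-count-weighted average of the per-client AUC values; its body appears only partially in the hunks above. A minimal sketch (simplified, with the intermediate variable names assumed) is:

.. code-block:: python

    def evaluate_metrics_aggregation(eval_metrics):
        """Return an example-count-weighted average AUC across clients."""
        # eval_metrics is a list of (num_examples, metrics_dict) tuples, one per client
        total_num = sum(num for num, _ in eval_metrics)
        auc_aggregated = (
            sum(metrics["AUC"] * num for num, metrics in eval_metrics) / total_num
        )
        return {"AUC": auc_aggregated}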
You should now see how the training does in the very first terminal (the one that started the server): .. code-block:: shell - INFO flwr 2023-11-20 11:21:56,454 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=5, round_timeout=None) - INFO flwr 2023-11-20 11:21:56,473 | app.py:176 | Flower ECE: gRPC server running (5 rounds), SSL is disabled - INFO flwr 2023-11-20 11:21:56,473 | server.py:89 | Initializing global parameters - INFO flwr 2023-11-20 11:21:56,473 | server.py:276 | Requesting initial parameters from one random client - INFO flwr 2023-11-20 11:22:38,302 | server.py:280 | Received initial parameters from one random client - INFO flwr 2023-11-20 11:22:38,302 | server.py:91 | Evaluating initial parameters - INFO flwr 2023-11-20 11:22:38,302 | server.py:104 | FL starting - DEBUG flwr 2023-11-20 11:22:38,302 | server.py:222 | fit_round 1: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,636 | server.py:236 | fit_round 1 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,643 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,653 | server.py:187 | evaluate_round 1 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,653 | server.py:222 | fit_round 2: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,721 | server.py:236 | fit_round 2 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,745 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,756 | server.py:187 | evaluate_round 2 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,756 | server.py:222 | fit_round 3: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,831 | server.py:236 | fit_round 3 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,868 | server.py:173 | evaluate_round 3: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,881 | server.py:187 | evaluate_round 3 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,881 | server.py:222 | fit_round 4: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,960 | server.py:236 | fit_round 4 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:39,012 | server.py:173 | evaluate_round 4: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:39,026 | server.py:187 | evaluate_round 4 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:39,026 | server.py:222 | fit_round 5: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:39,111 | server.py:236 | fit_round 5 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:39,177 | server.py:173 | evaluate_round 5: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:39,193 | server.py:187 | evaluate_round 5 received 2 results and 0 failures - INFO flwr 2023-11-20 11:22:39,193 | server.py:153 | FL finished in 0.8905023969999988 - INFO flwr 2023-11-20 11:22:39,193 | app.py:226 | app_fit: losses_distributed [(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)] - INFO flwr 2023-11-20 11:22:39,193 | app.py:227 | app_fit: metrics_distributed_fit {} - INFO flwr 2023-11-20 11:22:39,193 | app.py:228 | app_fit: metrics_distributed {'AUC': [(1, 0.7572), (2, 0.7705), (3, 0.77595), (4, 0.78), (5, 0.78385)]} - INFO flwr 2023-11-20 11:22:39,193 | app.py:229 | app_fit: losses_centralized [] - INFO flwr 2023-11-20 11:22:39,193 | app.py:230 | app_fit: metrics_centralized {} + INFO : 
Starting Flower server, config: num_rounds=5, no round_timeout + INFO : Flower ECE: gRPC server running (5 rounds), SSL is disabled + INFO : [INIT] + INFO : Requesting initial parameters from one random client + INFO : Received initial parameters from one random client + INFO : Evaluating initial global parameters + INFO : + INFO : [ROUND 1] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 2] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 3] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 4] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 5] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [SUMMARY] + INFO : Run finished 5 round(s) in 1.67s + INFO : History (loss, distributed): + INFO : round 1: 0 + INFO : round 2: 0 + INFO : round 3: 0 + INFO : round 4: 0 + INFO : round 5: 0 + INFO : History (metrics, distributed, evaluate): + INFO : {'AUC': [(1, 0.76755), (2, 0.775), (3, 0.77935), (4, 0.7836), (5, 0.7872)]} Congratulations! You've successfully built and run your first federated XGBoost system. 
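When the server is started from a Python script, the summary shown above can also be inspected programmatically: :code:`fl.server.start_server` returns a :code:`History` object whose distributed metrics mirror the logged values. A small illustrative snippet (not part of the example code):

.. code-block:: python

    history = fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=5),
        strategy=strategy,
    )
    # List of (round, AUC) tuples aggregated from the clients, e.g. [(1, 0.767), ...]
    print(history.metrics_distributed["AUC"])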
diff --git a/e2e/fastai/README.md b/e2e/framework-fastai/README.md similarity index 100% rename from e2e/fastai/README.md rename to e2e/framework-fastai/README.md diff --git a/e2e/fastai/client.py b/e2e/framework-fastai/client.py similarity index 100% rename from e2e/fastai/client.py rename to e2e/framework-fastai/client.py diff --git a/e2e/fastai/pyproject.toml b/e2e/framework-fastai/pyproject.toml similarity index 100% rename from e2e/fastai/pyproject.toml rename to e2e/framework-fastai/pyproject.toml diff --git a/e2e/fastai/simulation.py b/e2e/framework-fastai/simulation.py similarity index 100% rename from e2e/fastai/simulation.py rename to e2e/framework-fastai/simulation.py diff --git a/e2e/jax/README.md b/e2e/framework-jax/README.md similarity index 100% rename from e2e/jax/README.md rename to e2e/framework-jax/README.md diff --git a/e2e/jax/client.py b/e2e/framework-jax/client.py similarity index 100% rename from e2e/jax/client.py rename to e2e/framework-jax/client.py diff --git a/e2e/jax/jax_training.py b/e2e/framework-jax/jax_training.py similarity index 100% rename from e2e/jax/jax_training.py rename to e2e/framework-jax/jax_training.py diff --git a/e2e/jax/pyproject.toml b/e2e/framework-jax/pyproject.toml similarity index 100% rename from e2e/jax/pyproject.toml rename to e2e/framework-jax/pyproject.toml diff --git a/e2e/jax/simulation.py b/e2e/framework-jax/simulation.py similarity index 100% rename from e2e/jax/simulation.py rename to e2e/framework-jax/simulation.py diff --git a/e2e/opacus/.gitignore b/e2e/framework-opacus/.gitignore similarity index 100% rename from e2e/opacus/.gitignore rename to e2e/framework-opacus/.gitignore diff --git a/e2e/opacus/README.md b/e2e/framework-opacus/README.md similarity index 100% rename from e2e/opacus/README.md rename to e2e/framework-opacus/README.md diff --git a/e2e/opacus/client.py b/e2e/framework-opacus/client.py similarity index 100% rename from e2e/opacus/client.py rename to e2e/framework-opacus/client.py diff --git a/e2e/opacus/pyproject.toml b/e2e/framework-opacus/pyproject.toml similarity index 100% rename from e2e/opacus/pyproject.toml rename to e2e/framework-opacus/pyproject.toml diff --git a/e2e/opacus/simulation.py b/e2e/framework-opacus/simulation.py similarity index 100% rename from e2e/opacus/simulation.py rename to e2e/framework-opacus/simulation.py diff --git a/e2e/pandas/README.md b/e2e/framework-pandas/README.md similarity index 100% rename from e2e/pandas/README.md rename to e2e/framework-pandas/README.md diff --git a/e2e/pandas/client.py b/e2e/framework-pandas/client.py similarity index 100% rename from e2e/pandas/client.py rename to e2e/framework-pandas/client.py diff --git a/e2e/pandas/pyproject.toml b/e2e/framework-pandas/pyproject.toml similarity index 100% rename from e2e/pandas/pyproject.toml rename to e2e/framework-pandas/pyproject.toml diff --git a/e2e/pandas/server.py b/e2e/framework-pandas/server.py similarity index 100% rename from e2e/pandas/server.py rename to e2e/framework-pandas/server.py diff --git a/e2e/pandas/simulation.py b/e2e/framework-pandas/simulation.py similarity index 100% rename from e2e/pandas/simulation.py rename to e2e/framework-pandas/simulation.py diff --git a/e2e/pandas/strategy.py b/e2e/framework-pandas/strategy.py similarity index 100% rename from e2e/pandas/strategy.py rename to e2e/framework-pandas/strategy.py diff --git a/e2e/pytorch-lightning/README.md b/e2e/framework-pytorch-lightning/README.md similarity index 100% rename from e2e/pytorch-lightning/README.md rename to 
e2e/framework-pytorch-lightning/README.md diff --git a/e2e/pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py similarity index 100% rename from e2e/pytorch-lightning/client.py rename to e2e/framework-pytorch-lightning/client.py diff --git a/e2e/pytorch-lightning/mnist.py b/e2e/framework-pytorch-lightning/mnist.py similarity index 100% rename from e2e/pytorch-lightning/mnist.py rename to e2e/framework-pytorch-lightning/mnist.py diff --git a/e2e/pytorch-lightning/pyproject.toml b/e2e/framework-pytorch-lightning/pyproject.toml similarity index 100% rename from e2e/pytorch-lightning/pyproject.toml rename to e2e/framework-pytorch-lightning/pyproject.toml diff --git a/e2e/pytorch-lightning/simulation.py b/e2e/framework-pytorch-lightning/simulation.py similarity index 100% rename from e2e/pytorch-lightning/simulation.py rename to e2e/framework-pytorch-lightning/simulation.py diff --git a/e2e/pytorch/README.md b/e2e/framework-pytorch/README.md similarity index 100% rename from e2e/pytorch/README.md rename to e2e/framework-pytorch/README.md diff --git a/e2e/pytorch/client.py b/e2e/framework-pytorch/client.py similarity index 100% rename from e2e/pytorch/client.py rename to e2e/framework-pytorch/client.py diff --git a/e2e/pytorch/pyproject.toml b/e2e/framework-pytorch/pyproject.toml similarity index 100% rename from e2e/pytorch/pyproject.toml rename to e2e/framework-pytorch/pyproject.toml diff --git a/e2e/pytorch/simulation.py b/e2e/framework-pytorch/simulation.py similarity index 100% rename from e2e/pytorch/simulation.py rename to e2e/framework-pytorch/simulation.py diff --git a/e2e/pytorch/simulation_next.py b/e2e/framework-pytorch/simulation_next.py similarity index 100% rename from e2e/pytorch/simulation_next.py rename to e2e/framework-pytorch/simulation_next.py diff --git a/e2e/scikit-learn/README.md b/e2e/framework-scikit-learn/README.md similarity index 100% rename from e2e/scikit-learn/README.md rename to e2e/framework-scikit-learn/README.md diff --git a/e2e/scikit-learn/client.py b/e2e/framework-scikit-learn/client.py similarity index 100% rename from e2e/scikit-learn/client.py rename to e2e/framework-scikit-learn/client.py diff --git a/e2e/scikit-learn/pyproject.toml b/e2e/framework-scikit-learn/pyproject.toml similarity index 100% rename from e2e/scikit-learn/pyproject.toml rename to e2e/framework-scikit-learn/pyproject.toml diff --git a/e2e/scikit-learn/simulation.py b/e2e/framework-scikit-learn/simulation.py similarity index 100% rename from e2e/scikit-learn/simulation.py rename to e2e/framework-scikit-learn/simulation.py diff --git a/e2e/scikit-learn/utils.py b/e2e/framework-scikit-learn/utils.py similarity index 100% rename from e2e/scikit-learn/utils.py rename to e2e/framework-scikit-learn/utils.py diff --git a/e2e/tensorflow/README.md b/e2e/framework-tensorflow/README.md similarity index 100% rename from e2e/tensorflow/README.md rename to e2e/framework-tensorflow/README.md diff --git a/e2e/tensorflow/client.py b/e2e/framework-tensorflow/client.py similarity index 100% rename from e2e/tensorflow/client.py rename to e2e/framework-tensorflow/client.py diff --git a/e2e/tensorflow/pyproject.toml b/e2e/framework-tensorflow/pyproject.toml similarity index 100% rename from e2e/tensorflow/pyproject.toml rename to e2e/framework-tensorflow/pyproject.toml diff --git a/e2e/tabnet/simulation.py b/e2e/framework-tensorflow/simulation.py similarity index 100% rename from e2e/tabnet/simulation.py rename to e2e/framework-tensorflow/simulation.py diff --git 
a/e2e/tensorflow/simulation_next.py b/e2e/framework-tensorflow/simulation_next.py similarity index 100% rename from e2e/tensorflow/simulation_next.py rename to e2e/framework-tensorflow/simulation_next.py diff --git a/e2e/tabnet/README.md b/e2e/tabnet/README.md deleted file mode 100644 index 258043c3ffa8..000000000000 --- a/e2e/tabnet/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Flower with Tabnet testing - -This directory is used for testing Flower with Tabnet. - -It uses the `FedAvg` strategy. diff --git a/e2e/tabnet/client.py b/e2e/tabnet/client.py deleted file mode 100644 index 1a7ecfd68f73..000000000000 --- a/e2e/tabnet/client.py +++ /dev/null @@ -1,95 +0,0 @@ -import os - -import tabnet -import tensorflow as tf -import tensorflow_datasets as tfds - -import flwr as fl - -train_size = 125 -BATCH_SIZE = 50 -col_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"] - - -def transform(ds): - features = tf.unstack(ds["features"]) - labels = ds["label"] - - x = dict(zip(col_names, features)) - y = tf.one_hot(labels, 3) - return x, y - - -def prepare_iris_dataset(): - ds_full = tfds.load(name="iris", split=tfds.Split.TRAIN) - ds_full = ds_full.shuffle(150, seed=0) - - ds_train = ds_full.take(train_size) - ds_train = ds_train.map(transform) - ds_train = ds_train.batch(BATCH_SIZE) - - ds_test = ds_full.skip(train_size) - ds_test = ds_test.map(transform) - ds_test = ds_test.batch(BATCH_SIZE) - - feature_columns = [] - for col_name in col_names: - feature_columns.append(tf.feature_column.numeric_column(col_name)) - - return ds_train, ds_test, feature_columns - - -ds_train, ds_test, feature_columns = prepare_iris_dataset() -# Make TensorFlow log less verbose -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - -# Load TabNet model -model = tabnet.TabNetClassifier( - feature_columns, - num_classes=3, - feature_dim=8, - output_dim=4, - num_decision_steps=4, - relaxation_factor=1.0, - sparsity_coefficient=1e-5, - batch_momentum=0.98, - virtual_batch_size=None, - norm_type="group", - num_groups=1, -) -lr = tf.keras.optimizers.schedules.ExponentialDecay( - 0.01, decay_steps=100, decay_rate=0.9, staircase=False -) -optimizer = tf.keras.optimizers.Adam(lr) -model.compile(optimizer, loss="categorical_crossentropy", metrics=["accuracy"]) - - -# Define Flower client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config): - return model.get_weights() - - def fit(self, parameters, config): - model.set_weights(parameters) - model.fit(ds_train, epochs=25) - return model.get_weights(), len(ds_train), {} - - def evaluate(self, parameters, config): - model.set_weights(parameters) - loss, accuracy = model.evaluate(ds_test) - return loss, len(ds_train), {"accuracy": accuracy} - - -def client_fn(cid): - return FlowerClient().to_client() - - -app = fl.client.ClientApp( - client_fn=client_fn, -) - -if __name__ == "__main__": - # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) diff --git a/e2e/tabnet/pyproject.toml b/e2e/tabnet/pyproject.toml deleted file mode 100644 index 99379ddb607e..000000000000 --- a/e2e/tabnet/pyproject.toml +++ /dev/null @@ -1,25 +0,0 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[project] -name = "quickstart-tabnet-test" -version = "0.1.0" -description = "Tabnet Federated Learning E2E test with Flower" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] -dependencies = [ - "flwr[simulation] @ {root:parent:parent:uri}", - 
"tensorflow-cpu>=2.9.1,!=2.11.1; platform_machine == \"x86_64\"", - "tensorflow-macos>=2.9.1,!=2.11.1; sys_platform == \"darwin\" and platform_machine == \"arm64\"", - "tensorflow_datasets==4.9.2", - "tensorflow-io-gcs-filesystem<0.35.0", - "tabnet==0.1.6", -] - -[tool.hatch.build.targets.wheel] -packages = ["."] - -[tool.hatch.metadata] -allow-direct-references = true diff --git a/e2e/tensorflow/simulation.py b/e2e/tensorflow/simulation.py deleted file mode 100644 index bf05a77cf32a..000000000000 --- a/e2e/tensorflow/simulation.py +++ /dev/null @@ -1,14 +0,0 @@ -from client import client_fn - -import flwr as fl - -hist = fl.simulation.start_simulation( - client_fn=client_fn, - num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), -) - -assert ( - hist.losses_distributed[-1][1] == 0 - or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98 -) diff --git a/e2e/test.sh b/e2e/test_legacy.sh similarity index 96% rename from e2e/test.sh rename to e2e/test_legacy.sh index 4ea17a4f994b..dc17ca8c6378 100755 --- a/e2e/test.sh +++ b/e2e/test_legacy.sh @@ -2,7 +2,7 @@ set -e case "$1" in - pandas) + framework-pandas) server_file="server.py" ;; bare-https) diff --git a/e2e/test_driver.sh b/e2e/test_superlink.sh similarity index 90% rename from e2e/test_driver.sh rename to e2e/test_superlink.sh index e177863bab78..1bb81cc47ea1 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_superlink.sh @@ -2,7 +2,7 @@ set -e case "$1" in - pandas) + framework-pandas) server_arg="--insecure" client_arg="--insecure" server_dir="./" @@ -70,11 +70,11 @@ timeout 2m flower-superlink $server_arg $db_arg $rest_arg_superlink $server_auth sl_pid=$! sleep 3 -timeout 2m flower-client-app client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_1 & +timeout 2m flower-supernode client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_1 & cl1_pid=$! sleep 3 -timeout 2m flower-client-app client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_2 & +timeout 2m flower-supernode client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_2 & cl2_pid=$! sleep 3 diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index c1ba85b95879..ac0737673407 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Advanced Flower Example (PyTorch) This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md index 94707b5cbc98..375c539d13dd 100644 --- a/examples/advanced-tensorflow/README.md +++ b/examples/advanced-tensorflow/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, vision, fds] +dataset: [CIFAR-10] +framework: [tensorflow, Keras] +--- + # Advanced Flower Example (TensorFlow/Keras) This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. 
This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: diff --git a/examples/android-kotlin/README.md b/examples/android-kotlin/README.md index 2d0f704fdc0e..6cadb8e436fe 100644 --- a/examples/android-kotlin/README.md +++ b/examples/android-kotlin/README.md @@ -1,3 +1,9 @@ +--- +tags: [mobile, vision, sdk] +dataset: [CIFAR-10] +framework: [Android, Kotlin, TensorFlowLite] +--- + # Flower Android Client Example with Kotlin and TensorFlow Lite 2022 This example is similar to the Flower Android Example in Java: diff --git a/examples/android/README.md b/examples/android/README.md index f9f2bb93b8dc..83519f15d04d 100644 --- a/examples/android/README.md +++ b/examples/android/README.md @@ -1,3 +1,9 @@ +--- +tags: [mobile, vision, sdk] +dataset: [CIFAR-10] +framework: [Android, Java, TensorFlowLite] +--- + # Flower Android Example (TensorFlowLite) This example demonstrates a federated learning setup with Android clients in a background thread. The training on Android is done on a CIFAR10 dataset using TensorFlow Lite. The setup is as follows: diff --git a/examples/app-pytorch/README.md b/examples/app-pytorch/README.md index 14de3c7d632e..5cfae8440ed2 100644 --- a/examples/app-pytorch/README.md +++ b/examples/app-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower App (PyTorch) 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower diff --git a/examples/app-secure-aggregation/README.md b/examples/app-secure-aggregation/README.md index d1ea7bdc893f..8e483fb2f6bd 100644 --- a/examples/app-secure-aggregation/README.md +++ b/examples/app-secure-aggregation/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [] +framework: [numpy] +--- + # Secure aggregation with Flower (the SecAgg+ protocol) 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower diff --git a/examples/custom-metrics/README.md b/examples/custom-metrics/README.md index 317fb6336106..dd6985070cef 100644 --- a/examples/custom-metrics/README.md +++ b/examples/custom-metrics/README.md @@ -1,3 +1,10 @@ +--- +title: Example Flower App with Custom Metrics +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [tensorflow] +--- + # Flower Example using Custom Metrics This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of ready-available `scikit-learn` metrics: accuracy, recall, precision, and f1-score. 
diff --git a/examples/custom-mods/README.md b/examples/custom-mods/README.md index 6b03abcfbfe0..c2007eb323ae 100644 --- a/examples/custom-mods/README.md +++ b/examples/custom-mods/README.md @@ -1,3 +1,9 @@ +--- +tags: [mods, monitoring, app] +dataset: [CIFAR-10] +framework: [wandb, tensorboard] +--- + # Using custom mods 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower @@ -207,7 +213,7 @@ app = fl.client.ClientApp( client_fn=client_fn, mods=[ get_wandb_mod("Custom mods example"), - ], + ], ) ``` diff --git a/examples/doc/source/.gitignore b/examples/doc/source/.gitignore index dd449725e188..73ee14e96f68 100644 --- a/examples/doc/source/.gitignore +++ b/examples/doc/source/.gitignore @@ -1 +1,2 @@ *.md +index.rst diff --git a/examples/embedded-devices/README.md b/examples/embedded-devices/README.md index f1c5931b823a..86f19399932d 100644 --- a/examples/embedded-devices/README.md +++ b/examples/embedded-devices/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10, MNIST] +framework: [torch, tensorflow] +--- + # Federated Learning on Embedded Devices with Flower This example will show you how Flower makes it very easy to run Federated Learning workloads on edge devices. Here we'll be showing how to use NVIDIA Jetson devices and Raspberry Pi as Flower clients. You can run this example using either PyTorch or Tensorflow. The FL workload (i.e. model, dataset and training loop) is mostly borrowed from the [quickstart-pytorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch) and [quickstart-tensorflow](https://github.com/adap/flower/tree/main/examples/quickstart-tensorflow) examples. @@ -65,7 +71,7 @@ If you are working on this tutorial on your laptop or desktop, it can host the F - Install `pip`. In the terminal type: `sudo apt install python3-pip -y` - Now clone this directory. You just need to execute the `git clone` command shown at the top of this README.md on your device. - - Install Flower and your ML framework: We have prepared some convenient installation scripts that will install everything you need. You are free to install other versions of these ML frameworks to suit your needs. + - Install Flower and your ML framework of choice: We have prepared some convenient installation scripts that will install everything you need. You are free to install other versions of these ML frameworks to suit your needs. 
- If you want your clients to use PyTorch: `pip3 install -r requirements_pytorch.txt` - If you want your clients to use TensorFlow: `pip3 install -r requirements_tf.txt` diff --git a/examples/federated-kaplan-meier-fitter/README.md b/examples/federated-kaplan-meier-fitter/README.md index 1569467d6f82..20d4ca4c47af 100644 --- a/examples/federated-kaplan-meier-fitter/README.md +++ b/examples/federated-kaplan-meier-fitter/README.md @@ -1,3 +1,9 @@ +--- +tags: [estimator, medical] +dataset: [Waltons] +framework: [lifelines] +--- + # Flower Example using KaplanMeierFitter This is an introductory example on **federated survival analysis** using [Flower](https://flower.ai/) diff --git a/examples/fl-dp-sa/README.md b/examples/fl-dp-sa/README.md index 47eedb70a2b8..65c8a5b18fa8 100644 --- a/examples/fl-dp-sa/README.md +++ b/examples/fl-dp-sa/README.md @@ -1,4 +1,10 @@ -# fl_dp_sa +--- +tags: [basic, vision, fds] +dataset: [MNIST] +framework: [torch, torchvision] +--- + +# Example of Flower App with DP and SA This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation. Note: This example is designed for a small number of rounds and is intended for demonstration purposes. diff --git a/examples/fl-tabular/README.md b/examples/fl-tabular/README.md index 58afd1080b70..ee6dd7d00ef0 100644 --- a/examples/fl-tabular/README.md +++ b/examples/fl-tabular/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, tabular, fds] +dataset: [Adult Census Income] +framework: [scikit-learn, torch] +--- + # Flower Example on Adult Census Income Tabular Dataset This code exemplifies a federated learning setup using the Flower framework on the ["Adult Census Income"](https://huggingface.co/datasets/scikit-learn/adult-census-income) tabular dataset. The "Adult Census Income" dataset contains demographic information such as age, education, occupation, etc., with the target attribute being income level (\<=50K or >50K). The dataset is partitioned into subsets, simulating a federated environment with 5 clients, each holding a distinct portion of the data. Categorical variables are one-hot encoded, and the data is split into training and testing sets. Federated learning is conducted using the FedAvg strategy for 5 rounds. diff --git a/examples/flower-authentication/README.md b/examples/flower-authentication/README.md index 589270e621c9..d10780eeae5d 100644 --- a/examples/flower-authentication/README.md +++ b/examples/flower-authentication/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower Authentication with PyTorch 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower diff --git a/examples/flower-in-30-minutes/README.md b/examples/flower-in-30-minutes/README.md index 5fd9b882413b..faec3d72dae2 100644 --- a/examples/flower-in-30-minutes/README.md +++ b/examples/flower-in-30-minutes/README.md @@ -1,3 +1,9 @@ +--- +tags: [colab, vision, simulation] +dataset: [CIFAR-10] +framework: [torch] +--- + # 30-minute tutorial running Flower simulation with PyTorch This README links to a Jupyter notebook that you can either download and run locally or [![open it in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/examples/flower-in-30-minutes/tutorial.ipynb). This is a short 30-minute (or less!) tutorial showcasing the basics of Flower federated learning simulations using PyTorch. 
diff --git a/examples/flower-simulation-step-by-step-pytorch/README.md b/examples/flower-simulation-step-by-step-pytorch/README.md index beb8dd7f6f95..b00afedbe80b 100644 --- a/examples/flower-simulation-step-by-step-pytorch/README.md +++ b/examples/flower-simulation-step-by-step-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, simulation] +dataset: [MNIST] +framework: [torch] +--- + # Flower Simulation Step-by-Step > Since this tutorial (and its video series) was put together, Flower has been updated a few times. As a result, some of the steps to construct the environment (see below) have been updated. Some parts of the code have also been updated. Overall, the content of this tutorial and how things work remains the same as in the video tutorials. diff --git a/examples/flower-via-docker-compose/README.md b/examples/flower-via-docker-compose/README.md index 3ef1ac37bcda..3325a731fecf 100644 --- a/examples/flower-via-docker-compose/README.md +++ b/examples/flower-via-docker-compose/README.md @@ -1,3 +1,10 @@ +--- +title: Leveraging Flower and Docker for Device Heterogeneity Management in FL +tags: [deployment, vision, tutorial] +dataset: [CIFAR-10] +framework: [Docker, tensorflow] +--- + # Leveraging Flower and Docker for Device Heterogeneity Management in Federated Learning

diff --git a/examples/ios/README.md b/examples/ios/README.md index 4e17e7a674f3..aef4177dddf7 100644 --- a/examples/ios/README.md +++ b/examples/ios/README.md @@ -1,3 +1,9 @@ +--- +tags: [mobile, vision, sdk] +dataset: [MNIST] +framework: [Swift] +--- + # FLiOS - A Flower SDK for iOS Devices with Example FLiOS is a sample application for testing and benchmarking the Swift implementation of Flower. The default scenario uses the MNIST dataset and the associated digit recognition model. The app includes the Swift package in `./src/swift` and allows extension for other benchmarking scenarios. The app guides the user through the steps of the machine learning process that would be executed in a normal production environment as a background task of the application. The app is therefore aimed at researchers and research institutions to test their hypotheses and perform performance analyses. diff --git a/examples/llm-flowertune/README.md b/examples/llm-flowertune/README.md index 4f98072f8c7f..46076e0b2078 100644 --- a/examples/llm-flowertune/README.md +++ b/examples/llm-flowertune/README.md @@ -1,3 +1,10 @@ +--- +title: Federated LLM Fine-tuning with Flower +tags: [llm, nlp, LLama2] +dataset: [Alpaca-GPT4] +framework: [PEFT, torch] +--- + # LLM FlowerTune: Federated LLM Fine-tuning with Flower Large language models (LLMs), which have been trained on vast amounts of publicly accessible data, have shown remarkable effectiveness in a wide range of areas. diff --git a/examples/opacus/README.md b/examples/opacus/README.md index 6fc0d2ff49a0..aea5d0f689fe 100644 --- a/examples/opacus/README.md +++ b/examples/opacus/README.md @@ -1,3 +1,9 @@ +--- +tags: [dp, security, fds] +dataset: [CIFAR-10] +framework: [opacus, torch] +--- + # Training with Sample-Level Differential Privacy using Opacus Privacy Engine In this example, we demonstrate how to train a model with differential privacy (DP) using Flower. We employ PyTorch and integrate the Opacus Privacy Engine to achieve sample-level differential privacy. This setup ensures robust privacy guarantees during the client training phase. The code is adapted from the [PyTorch Quickstart example](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch). diff --git a/examples/pytorch-federated-variational-autoencoder/README.md b/examples/pytorch-federated-variational-autoencoder/README.md index 00af7a6328b2..52f94a16307c 100644 --- a/examples/pytorch-federated-variational-autoencoder/README.md +++ b/examples/pytorch-federated-variational-autoencoder/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower Example for Federated Variational Autoencoder using Pytorch This example demonstrates how a variational autoencoder (VAE) can be trained in a federated way using the Flower framework. diff --git a/examples/pytorch-from-centralized-to-federated/README.md b/examples/pytorch-from-centralized-to-federated/README.md index 06ee89dddcac..1bff7d02f52c 100644 --- a/examples/pytorch-from-centralized-to-federated/README.md +++ b/examples/pytorch-from-centralized-to-federated/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [torch] +--- + # PyTorch: From Centralized To Federated This example demonstrates how an already existing centralized PyTorch-based machine learning project can be federated with Flower. 
diff --git a/examples/quickstart-cpp/README.md b/examples/quickstart-cpp/README.md index d6cbeebe1bc6..61b76ece52b0 100644 --- a/examples/quickstart-cpp/README.md +++ b/examples/quickstart-cpp/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, linear regression, tabular] +dataset: [Synthetic] +framework: [C++] +--- + # Flower Clients in C++ (under development) In this example you will train a linear model on synthetic data using C++ clients. diff --git a/examples/quickstart-fastai/README.md b/examples/quickstart-fastai/README.md index 38ef23c95a1e..d1bf97cd4203 100644 --- a/examples/quickstart-fastai/README.md +++ b/examples/quickstart-fastai/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision] +dataset: [MNIST] +framework: [fastai] +--- + # Flower Example using fastai This introductory example to Flower uses [fastai](https://www.fast.ai/), but deep knowledge of fastai is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md index ce7790cd4af5..fa4330040ea7 100644 --- a/examples/quickstart-huggingface/README.md +++ b/examples/quickstart-huggingface/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, llm, nlp, sentiment] +dataset: [IMDB] +framework: [transformers] +--- + # Federated HuggingFace Transformers using Flower and PyTorch This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.ai/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline. diff --git a/examples/quickstart-jax/README.md b/examples/quickstart-jax/README.md index 836adf558d88..b47f3a82e13b 100644 --- a/examples/quickstart-jax/README.md +++ b/examples/quickstart-jax/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, linear regression] +dataset: [Synthetic] +framework: [JAX] +--- + # JAX: From Centralized To Federated This example demonstrates how an already existing centralized JAX-based machine learning project can be federated with Flower. diff --git a/examples/quickstart-mlcube/README.md b/examples/quickstart-mlcube/README.md index 8e6fc29b3ad8..f0c6c5664a82 100644 --- a/examples/quickstart-mlcube/README.md +++ b/examples/quickstart-mlcube/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, deployment] +dataset: [MNIST] +framework: [mlcube, tensorflow, Keras] +--- + # Flower Example using TensorFlow/Keras + MLCube This introductory example to Flower uses MLCube together with Keras, but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use-cases with MLCube. Running this example in itself is quite easy. diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md index cca55bcb946a..a4ac44bf8460 100644 --- a/examples/quickstart-mlx/README.md +++ b/examples/quickstart-mlx/README.md @@ -1,3 +1,10 @@ +--- +title: Simple Flower Example using MLX +tags: [quickstart, vision] +dataset: [MNIST] +framework: [MLX] +--- + # Flower Example using MLX This introductory example to Flower uses [MLX](https://ml-explore.github.io/mlx/build/html/index.html), but deep knowledge of MLX is not necessarily required to run the example. 
However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. diff --git a/examples/quickstart-monai/README.md b/examples/quickstart-monai/README.md index 4a9afef4f86a..dc31f03e4b1b 100644 --- a/examples/quickstart-monai/README.md +++ b/examples/quickstart-monai/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, medical, vision] +dataset: [MedNIST] +framework: [MONAI] +--- + # Flower Example using MONAI This introductory example to Flower uses MONAI, but deep knowledge of MONAI is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. diff --git a/examples/quickstart-pandas/README.md b/examples/quickstart-pandas/README.md index dd69f3ead3cb..0b4b3a6ac78a 100644 --- a/examples/quickstart-pandas/README.md +++ b/examples/quickstart-pandas/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, tabular, federated analytics] +dataset: [Iris] +framework: [pandas] +--- + # Flower Example using Pandas This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to diff --git a/examples/quickstart-pytorch-lightning/README.md b/examples/quickstart-pytorch-lightning/README.md index fb29c7e9e9ea..04eb911818fc 100644 --- a/examples/quickstart-pytorch-lightning/README.md +++ b/examples/quickstart-pytorch-lightning/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, fds] +dataset: [MNIST] +framework: [lightning] +--- + # Flower Example using PyTorch Lightning This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 93d6a593f362..8eace1ea6845 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower Example using PyTorch This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. 
diff --git a/examples/quickstart-sklearn-tabular/README.md b/examples/quickstart-sklearn-tabular/README.md index a975a9392800..b0b4cd1b84c0 100644 --- a/examples/quickstart-sklearn-tabular/README.md +++ b/examples/quickstart-sklearn-tabular/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, tabular, fds] +dataset: [Iris] +framework: [scikit-learn] +--- + # Flower Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on diff --git a/examples/quickstart-tabnet/README.md b/examples/quickstart-tabnet/README.md index 19a139f83064..e8be55eaacef 100644 --- a/examples/quickstart-tabnet/README.md +++ b/examples/quickstart-tabnet/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, tabular] +dataset: [Iris] +framework: [tabnet] +--- + # Flower TabNet Example using TensorFlow This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understanding how to adapt Flower to your use-cases. You can learn more about TabNet from [paper](https://arxiv.org/abs/1908.07442) and its implementation using TensorFlow at [this repository](https://github.com/titu1994/tf-TabNet). Note also that the basis of this example using federated learning is the example from the repository above. diff --git a/examples/quickstart-tensorflow/README.md b/examples/quickstart-tensorflow/README.md index ae1fe19834a3..386f8bbd96f0 100644 --- a/examples/quickstart-tensorflow/README.md +++ b/examples/quickstart-tensorflow/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, fds] +dataset: [CIFAR-10] +framework: [tensorflow] +--- + # Flower Example using TensorFlow/Keras This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. diff --git a/examples/simulation-pytorch/README.md b/examples/simulation-pytorch/README.md index 93f9e1acbac7..2dbfbc849ab7 100644 --- a/examples/simulation-pytorch/README.md +++ b/examples/simulation-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds, simulation] +dataset: [MNIST] +framework: [torch, torchvision] +--- + # Flower Simulation example using PyTorch This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. diff --git a/examples/simulation-tensorflow/README.md b/examples/simulation-tensorflow/README.md index 917d7b34c7af..047cb4379659 100644 --- a/examples/simulation-tensorflow/README.md +++ b/examples/simulation-tensorflow/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds, simulation] +dataset: [MNIST] +framework: [tensorflow, Keras] +--- + # Flower Simulation example using TensorFlow/Keras This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. 
This examples uses 100 clients by default. diff --git a/examples/sklearn-logreg-mnist/README.md b/examples/sklearn-logreg-mnist/README.md index 12b1a5e3bc1a..b117c5452086 100644 --- a/examples/sklearn-logreg-mnist/README.md +++ b/examples/sklearn-logreg-mnist/README.md @@ -1,4 +1,10 @@ -# Flower Example using scikit-learn +--- +tags: [basic, vision, logistic regression, fds] +dataset: [MNIST] +framework: [scikit-learn] +--- + +# Flower Logistic Regression Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. diff --git a/examples/tensorflow-privacy/README.md b/examples/tensorflow-privacy/README.md index a1f1be00f6b0..8156f92f60c9 100644 --- a/examples/tensorflow-privacy/README.md +++ b/examples/tensorflow-privacy/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds, privacy, dp] +dataset: [MNIST] +framework: [tensorflow] +--- + # Training with Sample-Level Differential Privacy using TensorFlow-Privacy Engine In this example, we demonstrate how to train a model with sample-level differential privacy (DP) using Flower. We employ TensorFlow and integrate the tensorflow-privacy Engine to achieve sample-level differential privacy. This setup ensures robust privacy guarantees during the client training phase. diff --git a/examples/vertical-fl/README.md b/examples/vertical-fl/README.md index ba8228a059f9..ab5d2210d8d5 100644 --- a/examples/vertical-fl/README.md +++ b/examples/vertical-fl/README.md @@ -1,3 +1,10 @@ +--- +title: Vertical FL Flower Example +tags: [vertical, tabular, advanced] +dataset: [Titanic] +framework: [torch, pandas, scikit-learn] +--- + # Vertical Federated Learning example This example will showcase how you can perform Vertical Federated Learning using diff --git a/examples/vit-finetune/README.md b/examples/vit-finetune/README.md index ac1652acf02d..957c0eda0b68 100644 --- a/examples/vit-finetune/README.md +++ b/examples/vit-finetune/README.md @@ -1,3 +1,10 @@ +--- +title: Federated finetuning of a ViT +tags: [finetuneing, vision, fds] +dataset: [Oxford Flower-102] +framework: [torch, torchvision] +--- + # Federated finetuning of a ViT This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple we'll be finetuning it to [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) datasset, creating 20 partitions using [Flower Datasets](https://flower.ai/docs/datasets/). We'll be finetuning just the exit `head` of the ViT, this means that the training is not that costly and each client requires just ~1GB of VRAM (for a batch size of 32 images). 
diff --git a/examples/whisper-federated-finetuning/README.md b/examples/whisper-federated-finetuning/README.md index ddebe51247b2..cfd0db842bae 100644 --- a/examples/whisper-federated-finetuning/README.md +++ b/examples/whisper-federated-finetuning/README.md @@ -1,3 +1,9 @@ +--- +tags: [finetuning, speech, transformers] +dataset: [SpeechCommands] +framework: [transformers, whisper] +--- + # On-device Federated Finetuning for Speech Classification This example demonstrates how to, from a pre-trained [Whisper](https://openai.com/research/whisper) model, finetune it for the downstream task of keyword spotting. We'll be implementing a federated downstream finetuning pipeline using Flower involving a total of 100 clients. As for the downstream dataset, we'll be using the [Google Speech Commands](https://huggingface.co/datasets/speech_commands) dataset for keyword spotting. We'll take the encoder part of the [Whisper-tiny](https://huggingface.co/openai/whisper-tiny) model, freeze its parameters, and learn a lightweight classification (\<800K parameters !!) head to correctly classify a spoken word. diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index dc6d7e3872d6..62fcba2bb06d 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, classification, tabular] +dataset: [HIGGS] +framework: [xgboost] +--- + # Flower Example using XGBoost (Comprehensive) This example demonstrates a comprehensive federated learning setup using Flower with XGBoost. diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index 713b6eab8bac..fa3e9d0dc6fb 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, classification, tabular] +dataset: [HIGGS] +framework: [xgboost] +--- + # Flower Example using XGBoost This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package. 
diff --git a/pyproject.toml b/pyproject.toml index 5daf007471aa..c5ab0e5edcee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ flower-simulation = "flwr.simulation.run_simulation:run_simulation_from_cli" python = "^3.8" # Mandatory dependencies numpy = "^1.21.0" -grpcio = "^1.60.0" +grpcio = "^1.60.0,!=1.64.2,!=1.65.0" protobuf = "^4.25.2" cryptography = "^42.0.4" pycryptodome = "^3.18.0" @@ -131,12 +131,7 @@ licensecheck = "==2024" pre-commit = "==3.5.0" [tool.isort] -line_length = 88 -indent = " " -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true +profile = "black" known_first_party = ["flwr", "flwr_tool"] [tool.black] diff --git a/src/docker/base/alpine/Dockerfile b/src/docker/base/alpine/Dockerfile index 04864b525e2e..9e58d82e3bda 100644 --- a/src/docker/base/alpine/Dockerfile +++ b/src/docker/base/alpine/Dockerfile @@ -54,9 +54,11 @@ FROM python:${PYTHON_VERSION}-${DISTRO}${DISTRO_VERSION} as base # required by the grpc package RUN apk add --no-cache \ libstdc++ \ + ca-certificates \ # add non-root user && adduser \ --no-create-home \ + --home /app \ --disabled-password \ --gecos "" \ --uid 49999 app \ diff --git a/src/docker/base/ubuntu/Dockerfile b/src/docker/base/ubuntu/Dockerfile index 4aeddc3f8d8d..960ed07edf96 100644 --- a/src/docker/base/ubuntu/Dockerfile +++ b/src/docker/base/ubuntu/Dockerfile @@ -48,29 +48,7 @@ RUN git clone https://github.com/pyenv/pyenv.git \ RUN LATEST=$(pyenv latest -k ${PYTHON_VERSION}) \ && python-build "${LATEST}" /usr/local/bin/python -FROM $DISTRO:$DISTRO_VERSION as base - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get update \ - && apt-get -y --no-install-recommends install \ - libsqlite3-0 \ - && rm -rf /var/lib/apt/lists/* - -COPY --from=python /usr/local/bin/python /usr/local/bin/python - -ENV PATH=/usr/local/bin/python/bin:$PATH \ - # Send stdout and stderr stream directly to the terminal. Ensures that no - # output is retained in a buffer if the application crashes. - PYTHONUNBUFFERED=1 \ - # Typically, bytecode is created on the first invocation to speed up following invocation. - # However, in Docker we only make a single invocation (when we start the container). - # Therefore, we can disable bytecode writing. - PYTHONDONTWRITEBYTECODE=1 \ - # Ensure that python encoding is always UTF-8. - PYTHONIOENCODING=UTF-8 \ - LANG=C.UTF-8 \ - LC_ALL=C.UTF-8 +ENV PATH=/usr/local/bin/python/bin:$PATH # Use a virtual environment to ensure that Python packages are installed in the same location # regardless of whether the subsequent image build is run with the app or the root user @@ -86,16 +64,43 @@ RUN pip install -U --no-cache-dir \ setuptools==${SETUPTOOLS_VERSION} \ ${FLWR_PACKAGE}==${FLWR_VERSION} -# add non-root user -RUN adduser \ +FROM $DISTRO:$DISTRO_VERSION as base + +COPY --from=python /usr/local/bin/python /usr/local/bin/python + +ENV DEBIAN_FRONTEND=noninteractive \ + PATH=/usr/local/bin/python/bin:$PATH + +RUN apt-get update \ + && apt-get -y --no-install-recommends install \ + libsqlite3-0 \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* \ + # add non-root user + && adduser \ --no-create-home \ + --home /app \ --disabled-password \ --gecos "" \ --uid 49999 app \ && mkdir -p /app \ - && chown -R app:app /python \ && chown -R app:app /app +COPY --from=python --chown=app:app /python/venv /python/venv + +ENV PATH=/python/venv/bin:$PATH \ + # Send stdout and stderr stream directly to the terminal. 
Ensures that no + # output is retained in a buffer if the application crashes. + PYTHONUNBUFFERED=1 \ + # Typically, bytecode is created on the first invocation to speed up following invocation. + # However, in Docker we only make a single invocation (when we start the container). + # Therefore, we can disable bytecode writing. + PYTHONDONTWRITEBYTECODE=1 \ + # Ensure that python encoding is always UTF-8. + PYTHONIOENCODING=UTF-8 \ + LANG=C.UTF-8 \ + LC_ALL=C.UTF-8 + WORKDIR /app USER app ENV HOME=/app diff --git a/src/docker/superexec/Dockerfile b/src/docker/superexec/Dockerfile new file mode 100644 index 000000000000..9e4cc722921e --- /dev/null +++ b/src/docker/superexec/Dockerfile @@ -0,0 +1,20 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +ARG BASE_REPOSITORY=flwr/base +ARG BASE_IMAGE +FROM $BASE_REPOSITORY:$BASE_IMAGE + +ENTRYPOINT ["flower-superexec"] diff --git a/src/docker/supernode/Dockerfile b/src/docker/supernode/Dockerfile index 8dce1c389a5b..8b78577b1201 100644 --- a/src/docker/supernode/Dockerfile +++ b/src/docker/supernode/Dockerfile @@ -17,4 +17,4 @@ ARG BASE_REPOSITORY=flwr/base ARG BASE_IMAGE FROM $BASE_REPOSITORY:$BASE_IMAGE -ENTRYPOINT ["flower-client-app"] +ENTRYPOINT ["flower-supernode"] diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index edbd5d91bb5b..77dc52b3258b 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -42,6 +42,7 @@ service Driver { message CreateRunRequest { string fab_id = 1; string fab_version = 2; + map override_config = 3; } message CreateRunResponse { sint64 run_id = 1; } diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 8e5f53b02ca8..d0d8dfcbb273 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -25,7 +25,10 @@ service Exec { rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {} } -message StartRunRequest { bytes fab_file = 1; } +message StartRunRequest { + bytes fab_file = 1; + map override_config = 2; +} message StartRunResponse { sint64 run_id = 1; } message StreamLogsRequest { sint64 run_id = 1; } message StreamLogsResponse { string log_output = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 76a7fd91532f..e41748381cab 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -21,6 +21,7 @@ message Run { sint64 run_id = 1; string fab_id = 2; string fab_version = 3; + map override_config = 4; } message GetRunRequest { sint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index d06a1d6dba96..33bf12e34b04 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -108,6 +108,14 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]: return 
load_from_string(toml_file.read()) +def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None: + for key, value in config_dict.items(): + if isinstance(value, dict): + _validate_run_config(config_dict[key], errors) + elif not isinstance(value, str): + errors.append(f"Config value of key {key} is not of type `str`.") + + # pylint: disable=too-many-branches def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: """Validate pyproject.toml fields.""" @@ -133,6 +141,8 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] else: if "publisher" not in config["flower"]: errors.append('Property "publisher" missing in [flower]') + if "config" in config["flower"]: + _validate_run_config(config["flower"]["config"], errors) if "components" not in config["flower"]: errors.append("Missing [flower.components] section") else: diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 9367cf6c9ffb..a0a2dc98556d 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -264,9 +264,11 @@ def new( bold=True, ) ) + + _add = " huggingface-cli login\n" if framework_str == "flowertune" else "" print( typer.style( - f" cd {package_name}\n" + " pip install -e .\n flwr run\n", + f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n", fg=typer.colors.BRIGHT_CYAN, bold=True, ) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index f5882bd14ab8..4ee2368f5794 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -18,13 +18,14 @@ from enum import Enum from logging import DEBUG from pathlib import Path -from typing import Optional +from typing import Dict, Optional import typer from typing_extensions import Annotated from flwr.cli import config_utils from flwr.cli.build import build +from flwr.common.config import parse_config_args from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log @@ -58,15 +59,20 @@ def run( Optional[Path], typer.Option(help="Path of the Flower project to run"), ] = None, + config_overrides: Annotated[ + Optional[str], + typer.Option( + "--config", + "-c", + help="Override configuration key-value pairs", + ), + ] = None, ) -> None: """Run Flower project.""" - if use_superexec: - _start_superexec_run(directory) - return - typer.secho("Loading project configuration... 
", fg=typer.colors.BLUE) - config, errors, warnings = config_utils.load_and_validate() + pyproject_path = directory / "pyproject.toml" if directory else None + config, errors, warnings = config_utils.load_and_validate(path=pyproject_path) if config is None: typer.secho( @@ -88,6 +94,12 @@ def run( typer.secho("Success", fg=typer.colors.GREEN) + if use_superexec: + _start_superexec_run( + parse_config_args(config_overrides, separator=","), directory + ) + return + server_app_ref = config["flower"]["components"]["serverapp"] client_app_ref = config["flower"]["components"]["clientapp"] @@ -115,7 +127,9 @@ def run( ) -def _start_superexec_run(directory: Optional[Path]) -> None: +def _start_superexec_run( + override_config: Dict[str, str], directory: Optional[Path] +) -> None: def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" log(DEBUG, channel_connectivity) @@ -132,6 +146,9 @@ def on_channel_state_change(channel_connectivity: str) -> None: fab_path = build(directory) - req = StartRunRequest(fab_file=Path(fab_path).read_bytes()) + req = StartRunRequest( + fab_file=Path(fab_path).read_bytes(), + override_config=override_config, + ) res = stub.StartRun(req) typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index fa446afcc1fb..851083d4abb7 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -19,6 +19,7 @@ import time from dataclasses import dataclass from logging import DEBUG, ERROR, INFO, WARN +from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union from cryptography.hazmat.primitives.asymmetric import ec @@ -41,6 +42,7 @@ from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential +from flwr.common.typing import Run from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection @@ -192,6 +194,7 @@ def _start_client_internal( max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, partition_id: Optional[int] = None, + flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -235,9 +238,11 @@ class `flwr.client.Client` (default: None) The maximum duration before the client stops trying to connect to the server in case of connection error. If set to None, there is no limit to the total time. - partitioni_id: Optional[int] (default: None) + partition_id: Optional[int] (default: None) The data partition index associated with this node. Better suited for prototyping purposes. + flwr_dir: Optional[Path] (default: None) + The fully resolved path containing installed Flower Apps. 
""" if insecure is None: insecure = root_certificates is None @@ -298,7 +303,7 @@ def _on_backoff(retry_state: RetryState) -> None: retry_invoker = RetryInvoker( wait_gen_factory=exponential, recoverable_exceptions=connection_error_type, - max_tries=max_retries, + max_tries=max_retries + 1 if max_retries is not None else None, max_time=max_wait_time, on_giveup=lambda retry_state: ( log( @@ -315,8 +320,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) node_state = NodeState(partition_id=partition_id) - # run_id -> (fab_id, fab_version) - run_info: Dict[int, Tuple[str, str]] = {} + runs: Dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -366,15 +370,17 @@ def _on_backoff(retry_state: RetryState) -> None: # Get run info run_id = message.metadata.run_id - if run_id not in run_info: + if run_id not in runs: if get_run is not None: - run_info[run_id] = get_run(run_id) + runs[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - run_info[run_id] = ("", "") + runs[run_id] = Run(run_id, "", "", {}) # Register context for this run - node_state.register_context(run_id=run_id) + node_state.register_context( + run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir + ) # Retrieve context for this run context = node_state.retrieve_context(run_id=run_id) @@ -388,7 +394,10 @@ def _on_backoff(retry_state: RetryState) -> None: # Handle app loading and task message try: # Load ClientApp instance - client_app: ClientApp = load_client_app_fn(*run_info[run_id]) + run: Run = runs[run_id] + client_app: ClientApp = load_client_app_fn( + run.fab_id, run.fab_version + ) # Execute ClientApp reply_message = client_app(message=message, context=context) @@ -573,7 +582,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ], ], diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index e4e32b3accd0..971b630e470b 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -27,6 +27,7 @@ from flwr.common.logger import log from flwr.common.message import Message from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run @contextmanager @@ -45,7 +46,7 @@ def grpc_adapter( # pylint: disable=R0913 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server via GrpcAdapter. 
diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 8c049861c672..3e9f261c1ecf 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -38,6 +38,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -73,7 +74,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Establish a gRPC connection to a gRPC server. diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 34dc0e417383..8062ce28fcc7 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, @@ -80,7 +81,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -266,7 +267,7 @@ def send(message: Message) -> None: # Cleanup metadata = None - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Call FleetAPI get_run_request = GetRunRequest(run_id=run_id) get_run_response: GetRunResponse = retry_invoker.invoke( @@ -275,7 +276,12 @@ def get_run(run_id: int) -> Tuple[str, str]: ) # Return fab_id and fab_version - return get_run_response.run.fab_id, get_run_response.run.fab_version + return Run( + run_id, + get_run_response.run.fab_id, + get_run_response.run.fab_version, + dict(get_run_response.run.override_config.items()), + ) try: # Yield methods diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index cafcbbde3984..9ce4c9620c43 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -145,7 +145,7 @@ def test_client_without_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet()), + context=Context(state=RecordSet(), run_config={}), ) # Assert @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet()), + context=Context(state=RecordSet(), run_config={}), ) # Assert diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 36844a2983a1..5e4c4411e1f7 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -73,7 +73,7 @@ def func(configs: 
Dict[str, ConfigsRecordValues]) -> ConfigsRecord: def _make_ctxt() -> Context: cfg = ConfigsRecord(SecAggPlusState().to_dict()) - return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg})) + return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}), run_config={}) def _make_set_state_fn( diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index 035e41639b10..7a1dd8988399 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -104,7 +104,7 @@ def test_multiple_mods(self) -> None: state = RecordSet() state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) - context = Context(state=state) + context = Context(state=state, run_config={}) message = _get_dummy_flower_message() # Execute @@ -129,7 +129,7 @@ def test_filter(self) -> None: # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - context = Context(state=RecordSet()) + context = Context(state=RecordSet(), run_config={}) message = _get_dummy_flower_message() def filter_mod( diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index cda00d25b62c..2b090eba9720 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -15,9 +15,21 @@ """Node state.""" +from dataclasses import dataclass +from pathlib import Path from typing import Any, Dict, Optional from flwr.common import Context, RecordSet +from flwr.common.config import get_fused_config +from flwr.common.typing import Run + + +@dataclass() +class RunInfo: + """Contains the Context and initial run_config of a Run.""" + + context: Context + initial_run_config: Dict[str, str] class NodeState: @@ -25,20 +37,31 @@ class NodeState: def __init__(self, partition_id: Optional[int]) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.run_contexts: Dict[int, Context] = {} + self.run_infos: Dict[int, RunInfo] = {} self._partition_id = partition_id - def register_context(self, run_id: int) -> None: + def register_context( + self, + run_id: int, + run: Optional[Run] = None, + flwr_dir: Optional[Path] = None, + ) -> None: """Register new run context for this node.""" - if run_id not in self.run_contexts: - self.run_contexts[run_id] = Context( - state=RecordSet(), partition_id=self._partition_id + if run_id not in self.run_infos: + initial_run_config = get_fused_config(run, flwr_dir) if run else {} + self.run_infos[run_id] = RunInfo( + initial_run_config=initial_run_config, + context=Context( + state=RecordSet(), + run_config=initial_run_config.copy(), + partition_id=self._partition_id, + ), ) def retrieve_context(self, run_id: int) -> Context: """Get run context given a run_id.""" - if run_id in self.run_contexts: - return self.run_contexts[run_id] + if run_id in self.run_infos: + return self.run_infos[run_id].context raise RuntimeError( f"Context for run_id={run_id} doesn't exist." @@ -48,4 +71,9 @@ def retrieve_context(self, run_id: int) -> Context: def update_context(self, run_id: int, context: Context) -> None: """Update run context.""" - self.run_contexts[run_id] = context + if context.run_config != self.run_infos[run_id].initial_run_config: + raise ValueError( + "The `run_config` field of the `Context` object cannot be " + f"modified (run_id: {run_id})." 
+ ) + self.run_infos[run_id].context = context diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index 311dbd41d742..effd64a3ae7a 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None: node_state.update_context(run_id=run_id, context=updated_state) # Verify values - for run_id, context in node_state.run_contexts.items(): + for run_id, run_info in node_state.run_infos.items(): assert ( - context.state.configs_records["counter"]["count"] == expected_values[run_id] + run_info.context.state.configs_records["counter"]["count"] + == expected_values[run_id] ) diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index db5bd7eb6770..0efa5731ae51 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -91,7 +92,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -344,16 +345,21 @@ def send(message: Message) -> None: res.results, # pylint: disable=no-member ) - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Construct the request req = GetRunRequest(run_id=run_id) # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) if res is None: - return "", "" + return Run(run_id, "", "", {}) - return res.run.fab_id, res.run.fab_version + return Run( + run_id, + res.run.fab_id, + res.run.fab_version, + dict(res.run.override_config.items()), + ) try: # Yield methods diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 20e0c44eab14..355a2a13a0e5 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -68,6 +68,7 @@ def run_supernode() -> None: max_retries=args.max_retries, max_wait_time=args.max_wait_time, partition_id=args.partition_id, + flwr_dir=get_flwr_dir(args.flwr_dir), ) # Graceful shutdown @@ -178,7 +179,7 @@ def _get_load_client_app_fn( else: flwr_dir = Path(args.flwr_dir).absolute() - sys.path.insert(0, str(flwr_dir.absolute())) + inserted_path = None default_app_ref: str = getattr(args, "client-app") @@ -188,6 +189,11 @@ def _get_load_client_app_fn( "Flower SuperNode will load and validate ClientApp `%s`", getattr(args, "client-app"), ) + # Insert sys.path + dir_path = Path(args.dir).absolute() + sys.path.insert(0, str(dir_path)) + inserted_path = str(dir_path) + valid, error_msg = validate(default_app_ref) if not valid and error_msg: raise LoadClientAppError(error_msg) from None @@ -196,7 +202,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: # If multi-app feature is disabled if not multi_app: # Get sys path to be inserted - sys_path = Path(args.dir).absolute() + dir_path = Path(args.dir).absolute() # Set app reference client_app_ref = default_app_ref @@ -209,7 +215,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: log(WARN, "FAB ID 
is not provided; the default ClientApp will be loaded.") # Get sys path to be inserted - sys_path = Path(args.dir).absolute() + dir_path = Path(args.dir).absolute() # Set app reference client_app_ref = default_app_ref @@ -222,13 +228,21 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: raise LoadClientAppError("Failed to load ClientApp") from e # Get sys path to be inserted - sys_path = Path(project_dir).absolute() + dir_path = Path(project_dir).absolute() # Set app reference client_app_ref = config["flower"]["components"]["clientapp"] # Set sys.path - sys.path.insert(0, str(sys_path)) + nonlocal inserted_path + if inserted_path != str(dir_path): + # Remove the previously inserted path + if inserted_path is not None: + sys.path.remove(inserted_path) + # Insert the new path + sys.path.insert(0, str(dir_path)) + + inserted_path = str(dir_path) # Load ClientApp log( @@ -236,7 +250,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: "Loading ClientApp `%s`", client_app_ref, ) - client_app = load_app(client_app_ref, LoadClientAppError, sys_path) + client_app = load_app(client_app_ref, LoadClientAppError, dir_path) if not isinstance(client_app, ClientApp): raise LoadClientAppError( @@ -345,8 +359,8 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: "--max-retries", type=int, default=None, - help="The maximum number of times the client will try to connect to the" - "server before giving up in case of a connection error. By default," + help="The maximum number of times the client will try to reconnect to the" + "SuperLink before giving up in case of a connection error. By default," "it is set to None, meaning there is no limit to the number of tries.", ) parser.add_argument( @@ -354,7 +368,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: type=float, default=None, help="The maximum duration before the client stops trying to" - "connect to the server in case of connection error. By default, it" + "connect to the SuperLink in case of connection error. By default, it" "is set to None, meaning there is no limit to the total time.", ) parser.add_argument( diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 20de00a6fba9..9770bdb4af2b 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -16,12 +16,13 @@ import os from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import tomli from flwr.cli.config_utils import validate_fields from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME +from flwr.common.typing import Run def get_flwr_dir(provided_path: Optional[str] = None) -> Path: @@ -30,7 +31,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path: return Path( os.getenv( FLWR_HOME, - f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr", + Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr", ) ) return Path(provided_path).absolute() @@ -71,3 +72,75 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: ) return config + + +def _fuse_dicts( + main_dict: Dict[str, str], override_dict: Dict[str, str] +) -> Dict[str, str]: + fused_dict = main_dict.copy() + + for key, value in override_dict.items(): + if key in main_dict: + fused_dict[key] = value + + return fused_dict + + +def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: + """Merge the overrides from a `Run` with the config from a FAB. 
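For orientation, a small sketch of the directory-resolution order that `get_flwr_dir` now follows (paths are illustrative, mirroring the new tests further below):

import os
from pathlib import Path
from unittest.mock import patch

from flwr.common.config import get_flwr_dir

# Precedence: explicit argument, then FLWR_HOME, then $XDG_DATA_HOME/.flwr,
# then $HOME/.flwr.
with patch.dict(os.environ, {"FLWR_HOME": "/custom/flwr/home"}):
    assert get_flwr_dir() == Path("/custom/flwr/home")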
+ + Get the config using the fab_id and the fab_version, remove the nesting by adding + the nested keys as prefixes separated by dots, and fuse it with the override dict. + """ + if not run.fab_id or not run.fab_version: + return {} + + project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir) + + default_config = get_project_config(project_dir)["flower"].get("config", {}) + flat_default_config = flatten_dict(default_config) + + return _fuse_dicts(flat_default_config, run.override_config) + + +def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]: + """Flatten dict by joining nested keys with a given separator.""" + items: List[Tuple[str, str]] = [] + separator: str = "." + for k, v in raw_dict.items(): + new_key = f"{parent_key}{separator}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, parent_key=new_key).items()) + elif isinstance(v, str): + items.append((new_key, v)) + else: + raise ValueError( + f"The value for key {k} needs to be a `str` or a `dict`.", + ) + return dict(items) + + +def parse_config_args( + config_overrides: Optional[str], + separator: str = ",", +) -> Dict[str, str]: + """Parse separator separated list of key-value pairs separated by '='.""" + overrides: Dict[str, str] = {} + + if config_overrides is None: + return overrides + + overrides_list = config_overrides.split(separator) + if ( + len(overrides_list) == 1 + and "=" not in overrides_list + and overrides_list[0].endswith(".toml") + ): + with Path(overrides_list[0]).open("rb") as config_file: + overrides = flatten_dict(tomli.load(config_file)) + else: + for kv_pair in overrides_list: + key, value = kv_pair.split("=") + overrides[key] = value + + return overrides diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py new file mode 100644 index 000000000000..fe429bab9cb5 --- /dev/null +++ b/src/py/flwr/common/config_test.py @@ -0,0 +1,230 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test util functions handling Flower config.""" + +import os +import textwrap +from pathlib import Path +from unittest.mock import patch + +import pytest + +from .config import ( + _fuse_dicts, + flatten_dict, + get_flwr_dir, + get_project_config, + get_project_dir, + parse_config_args, +) + +# Mock constants +FAB_CONFIG_FILE = "pyproject.toml" + + +def test_get_flwr_dir_with_provided_path() -> None: + """Test get_flwr_dir with a provided valid path.""" + provided_path = "." 
+ assert get_flwr_dir(provided_path) == Path(provided_path).absolute() + + +def test_get_flwr_dir_without_provided_path() -> None: + """Test get_flwr_dir without a provided path, using default home directory.""" + with patch.dict(os.environ, {"HOME": "/home/user"}): + assert get_flwr_dir() == Path("/home/user/.flwr") + + +def test_get_flwr_dir_with_flwr_home() -> None: + """Test get_flwr_dir with FLWR_HOME environment variable set.""" + with patch.dict(os.environ, {"FLWR_HOME": "/custom/flwr/home"}): + assert get_flwr_dir() == Path("/custom/flwr/home") + + +def test_get_flwr_dir_with_xdg_data_home() -> None: + """Test get_flwr_dir with FLWR_HOME environment variable set.""" + with patch.dict(os.environ, {"XDG_DATA_HOME": "/custom/data/home"}): + assert get_flwr_dir() == Path("/custom/data/home/.flwr") + + +def test_get_project_dir_invalid_fab_id() -> None: + """Test get_project_dir with an invalid fab_id.""" + with pytest.raises(ValueError): + get_project_dir("invalid_fab_id", "1.0.0") + + +def test_get_project_dir_valid() -> None: + """Test get_project_dir with an valid fab_id and version.""" + app_path = get_project_dir("app_name/user", "1.0.0", flwr_dir=".") + assert app_path == Path("apps") / "app_name" / "user" / "1.0.0" + + +def test_get_project_config_file_not_found() -> None: + """Test get_project_config when the configuration file is not found.""" + with pytest.raises(FileNotFoundError): + get_project_config("/invalid/dir") + + +def test_get_fused_config_valid(tmp_path: Path) -> None: + """Test get_project_config when the configuration file is not found.""" + pyproject_toml_content = """ + [build-system] + requires = ["hatchling"] + build-backend = "hatchling.build" + + [project] + name = "fedgpt" + version = "1.0.0" + description = "" + license = {text = "Apache License (2.0)"} + dependencies = [ + "flwr[simulation]>=1.9.0,<2.0", + "numpy>=1.21.0", + ] + + [flower] + publisher = "flwrlabs" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.config] + num_server_rounds = "10" + momentum = "0.1" + lr = "0.01" + serverapp.test = "key" + + [flower.config.clientapp] + test = "key" + """ + overrides = { + "num_server_rounds": "5", + "lr": "0.2", + "serverapp.test": "overriden", + } + expected_config = { + "num_server_rounds": "5", + "momentum": "0.1", + "lr": "0.2", + "serverapp.test": "overriden", + "clientapp.test": "key", + } + # Current directory + origin = Path.cwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open(FAB_CONFIG_FILE, "w", encoding="utf-8") as f: + f.write(textwrap.dedent(pyproject_toml_content)) + + # Execute + default_config = get_project_config(tmp_path)["flower"].get("config", {}) + + config = _fuse_dicts(flatten_dict(default_config), overrides) + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_get_project_config_file_valid(tmp_path: Path) -> None: + """Test get_project_config when the configuration file is not found.""" + pyproject_toml_content = """ + [build-system] + requires = ["hatchling"] + build-backend = "hatchling.build" + + [project] + name = "fedgpt" + version = "1.0.0" + description = "" + license = {text = "Apache License (2.0)"} + dependencies = [ + "flwr[simulation]>=1.9.0,<2.0", + "numpy>=1.21.0", + ] + + [flower] + publisher = "flwrlabs" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.config] + num_server_rounds = "10" + momentum = "0.1" + lr = "0.01" + """ + 
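A compact sketch of the flatten-then-fuse behaviour that `get_fused_config` relies on (values modelled on the test above):

from flwr.common.config import _fuse_dicts, flatten_dict

flat = flatten_dict({"serverapp": {"test": "key"}, "lr": "0.01"})
# -> {"serverapp.test": "key", "lr": "0.01"}

# Only keys that already exist in the FAB config are overridden; unknown keys
# are silently dropped.
assert _fuse_dicts(flat, {"lr": "0.2", "unknown": "x"}) == {
    "serverapp.test": "key",
    "lr": "0.2",
}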
expected_config = { + "build-system": {"build-backend": "hatchling.build", "requires": ["hatchling"]}, + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": {"text": "Apache License (2.0)"}, + "dependencies": ["flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0"], + }, + "flower": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "config": { + "num_server_rounds": "10", + "momentum": "0.1", + "lr": "0.01", + }, + }, + } + # Current directory + origin = Path.cwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open(FAB_CONFIG_FILE, "w", encoding="utf-8") as f: + f.write(textwrap.dedent(pyproject_toml_content)) + + # Execute + config = get_project_config(tmp_path) + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_flatten_dict() -> None: + """Test flatten_dict with a nested dictionary.""" + raw_dict = {"a": {"b": {"c": "d"}}, "e": "f"} + expected = {"a.b.c": "d", "e": "f"} + assert flatten_dict(raw_dict) == expected + + +def test_parse_config_args_none() -> None: + """Test parse_config_args with None as input.""" + assert not parse_config_args(None) + + +def test_parse_config_args_overrides() -> None: + """Test parse_config_args with key-value pairs.""" + assert parse_config_args("key1=value1,key2=value2") == { + "key1": "value1", + "key2": "value2", + } diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 8fe0f1781817..8120723ce9e9 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -16,7 +16,7 @@ from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional from .record import RecordSet @@ -34,6 +34,10 @@ class Context: executing mods. It can also be used as a memory to access at different points during the lifecycle of this entity (e.g. across multiple rounds) + run_config : Dict[str, str] + A config (key/value mapping) held by the entity in a given run and that will + stay local. It can be used at any point during the lifecycle of this entity + (e.g. across multiple rounds) partition_id : Optional[int] (default: None) An index that specifies the data partition that the ClientApp using this Context object should make use of. Setting this attribute is better suited for @@ -42,7 +46,14 @@ class Context: state: RecordSet partition_id: Optional[int] - - def __init__(self, state: RecordSet, partition_id: Optional[int] = None) -> None: + run_config: Dict[str, str] + + def __init__( + self, + state: RecordSet, + run_config: Dict[str, str], + partition_id: Optional[int] = None, + ) -> None: self.state = state + self.run_config = run_config self.partition_id = partition_id diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py index 7225b0663ae7..2077f9beaca0 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -197,6 +197,44 @@ def warn_deprecated_feature(name: str) -> None: ) +def warn_deprecated_feature_with_example( + deprecation_message: str, example_message: str, code_example: str +) -> None: + """Warn if a feature is deprecated and show code example.""" + log( + WARN, + """DEPRECATED FEATURE: %s + + Check the following `FEATURE UPDATE` warning message for the preferred + new mechanism to use this feature in Flower. 
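Since `run_config` is now a required `Context` argument, existing call sites construct it explicitly; a minimal sketch:

from flwr.common import Context, RecordSet

# An empty dict preserves the previous behaviour; at runtime the ClientApp
# receives the fused run config registered by NodeState.
context = Context(state=RecordSet(), run_config={})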
+ """, + deprecation_message, + ) + log( + WARN, + """FEATURE UPDATE: %s + ------------------------------------------------------------ + %s + ------------------------------------------------------------ + """, + example_message, + code_example, + ) + + +def warn_unsupported_feature(name: str) -> None: + """Warn the user when they use an unsupported feature.""" + log( + WARN, + """UNSUPPORTED FEATURE: %s + + This is an unsupported feature. It will be removed + entirely in future versions of Flower. + """, + name, + ) + + def set_logger_propagation( child_logger: logging.Logger, value: bool = True ) -> logging.Logger: diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index f51830955679..04d2cf5bbf7f 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -194,3 +194,4 @@ class Run: run_id: int fab_id: str fab_version: str + override_config: Dict[str, str] diff --git a/src/py/flwr/proto/common_pb2.py b/src/py/flwr/proto/common_pb2.py new file mode 100644 index 000000000000..8a6430137f05 --- /dev/null +++ b/src/py/flwr/proto/common_pb2.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/common.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/common.proto\x12\nflwr.protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.common_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/common_pb2.pyi b/src/py/flwr/proto/common_pb2.pyi new file mode 100644 index 000000000000..e08fa11c2caa --- /dev/null +++ b/src/py/flwr/proto/common_pb2.pyi @@ -0,0 +1,7 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import google.protobuf.descriptor + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor diff --git a/src/py/flwr/proto/common_pb2_grpc.py b/src/py/flwr/proto/common_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/common_pb2_grpc.pyi b/src/py/flwr/proto/common_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
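The regenerated stubs expose `override_config` as a `map<string, string>`; a short sketch of constructing the request in Python (field values are illustrative):

from flwr.proto.driver_pb2 import CreateRunRequest

req = CreateRunRequest(
    fab_id="flwrlabs/fedgpt",
    fab_version="1.0.0",
    override_config={"num_server_rounds": "5"},
)
assert req.override_config["num_server_rounds"] == "5"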
+isort:skip_file +""" diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index a2458b445563..07975937328d 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -17,29 +17,33 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"\xb9\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST']._serialized_start=107 - _globals['_CREATERUNREQUEST']._serialized_end=162 - _globals['_CREATERUNRESPONSE']._serialized_start=164 - _globals['_CREATERUNRESPONSE']._serialized_end=199 - _globals['_GETNODESREQUEST']._serialized_start=201 - _globals['_GETNODESREQUEST']._serialized_end=234 - _globals['_GETNODESRESPONSE']._serialized_start=236 - _globals['_GETNODESRESPONSE']._serialized_end=287 - _globals['_PUSHTASKINSREQUEST']._serialized_start=289 - _globals['_PUSHTASKINSREQUEST']._serialized_end=353 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=355 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=394 - _globals['_PULLTASKRESREQUEST']._serialized_start=396 - _globals['_PULLTASKRESREQUEST']._serialized_end=466 - _globals['_PULLTASKRESRESPONSE']._serialized_start=468 - _globals['_PULLTASKRESRESPONSE']._serialized_end=533 - _globals['_DRIVER']._serialized_start=536 - _globals['_DRIVER']._serialized_end=924 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_CREATERUNREQUEST']._serialized_start=108 + _globals['_CREATERUNREQUEST']._serialized_end=293 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=240 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=293 + _globals['_CREATERUNRESPONSE']._serialized_start=295 + _globals['_CREATERUNRESPONSE']._serialized_end=330 + _globals['_GETNODESREQUEST']._serialized_start=332 + _globals['_GETNODESREQUEST']._serialized_end=365 + _globals['_GETNODESRESPONSE']._serialized_start=367 + _globals['_GETNODESRESPONSE']._serialized_end=418 + _globals['_PUSHTASKINSREQUEST']._serialized_start=420 + _globals['_PUSHTASKINSREQUEST']._serialized_end=484 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=486 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=525 + _globals['_PULLTASKRESREQUEST']._serialized_start=527 + _globals['_PULLTASKRESREQUEST']._serialized_end=597 + _globals['_PULLTASKRESRESPONSE']._serialized_start=599 + _globals['_PULLTASKRESRESPONSE']._serialized_end=664 + _globals['_DRIVER']._serialized_start=667 + _globals['_DRIVER']._serialized_end=1055 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 2d8d11fb59a3..95d4c9785ff1 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -16,16 +16,33 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class CreateRunRequest(google.protobuf.message.Message): """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... 
def __init__(self, *, fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... global___CreateRunRequest = CreateRunRequest class CreateRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 7b037a9454c0..4aee0f4a882f 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -14,21 +14,25 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"#\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"\xa4\x01\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.exec_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_STARTRUNREQUEST']._serialized_start=37 - _globals['_STARTRUNREQUEST']._serialized_end=72 - _globals['_STARTRUNRESPONSE']._serialized_start=74 - _globals['_STARTRUNRESPONSE']._serialized_end=108 - _globals['_STREAMLOGSREQUEST']._serialized_start=110 - _globals['_STREAMLOGSREQUEST']._serialized_end=145 - _globals['_STREAMLOGSRESPONSE']._serialized_start=147 - _globals['_STREAMLOGSRESPONSE']._serialized_end=187 - _globals['_EXEC']._serialized_start=190 - _globals['_EXEC']._serialized_end=350 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_STARTRUNREQUEST']._serialized_start=38 + _globals['_STARTRUNREQUEST']._serialized_end=202 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=149 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=202 + _globals['_STARTRUNRESPONSE']._serialized_start=204 + _globals['_STARTRUNRESPONSE']._serialized_end=238 + _globals['_STREAMLOGSREQUEST']._serialized_start=240 + 
_globals['_STREAMLOGSREQUEST']._serialized_end=275 + _globals['_STREAMLOGSRESPONSE']._serialized_start=277 + _globals['_STREAMLOGSRESPONSE']._serialized_end=317 + _globals['_EXEC']._serialized_start=320 + _globals['_EXEC']._serialized_end=480 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 466812808da8..8065fc1de1b4 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,13 +13,30 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class StartRunRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_FILE_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_file: builtins.bytes + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, fab_file: builtins.bytes = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","override_config",b"override_config"]) -> None: ... 
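The Exec API request mirrors this; a sketch of the call `flwr run` issues against the SuperExec (the FAB path is hypothetical):

from pathlib import Path

from flwr.proto.exec_pb2 import StartRunRequest

req = StartRunRequest(
    fab_file=Path("dist/fedgpt.fab").read_bytes(),
    override_config={"lr": "0.2"},
)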
global___StartRunRequest = StartRunRequest class StartRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 13f06e7169aa..d6531201f647 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -14,17 +14,21 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\"\xaf\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_RUN']._serialized_start=36 - _globals['_RUN']._serialized_end=94 - _globals['_GETRUNREQUEST']._serialized_start=96 - _globals['_GETRUNREQUEST']._serialized_end=127 - _globals['_GETRUNRESPONSE']._serialized_start=129 - _globals['_GETRUNRESPONSE']._serialized_end=175 + _globals['_RUN_OVERRIDECONFIGENTRY']._options = None + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_RUN']._serialized_start=37 + _globals['_RUN']._serialized_end=212 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=159 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=212 + _globals['_GETRUNREQUEST']._serialized_start=214 + _globals['_GETRUNREQUEST']._serialized_end=245 + _globals['_GETRUNRESPONSE']._serialized_start=247 + _globals['_GETRUNRESPONSE']._serialized_end=293 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 401d27855a41..3c58c04c1734 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,19 +13,36 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class Run(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... 
+ RUN_ID_FIELD_NUMBER: builtins.int FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int run_id: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, run_id: builtins.int = ..., fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run class GetRunRequest(google.protobuf.message.Message): diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index 546ce263e2d5..896b46298327 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -28,6 +28,7 @@ from .server import Server as Server from .server_app import ServerApp as ServerApp from .server_config import ServerConfig as ServerConfig +from .serverapp_components import ServerAppComponents as ServerAppComponents __all__ = [ "ClientManager", @@ -36,6 +37,7 @@ "LegacyContext", "Server", "ServerApp", + "ServerAppComponents", "ServerConfig", "SimpleClientManager", "run_server_app", diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py index 0b00c98bb16d..9e120c824103 100644 --- a/src/py/flwr/server/compat/legacy_context.py +++ b/src/py/flwr/server/compat/legacy_context.py @@ -52,4 +52,4 @@ def __init__( self.strategy = strategy self.client_manager = client_manager self.history = History() - super().__init__(state) + super().__init__(state, run_config={}) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index e614df659e3f..84da5882eb73 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -16,8 +16,8 @@ import time import warnings -from logging import DEBUG, ERROR, WARNING -from typing import Iterable, List, Optional, Tuple, cast +from logging import DEBUG, WARNING +from typing import Iterable, List, Optional, cast import grpc @@ -27,8 +27,6 @@ from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.common.typing import Run from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunRequest, - CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -53,167 +51,103 @@ """ -class GrpcDriverStub: - """`GrpcDriverStub` provides access to the gRPC Driver API/service. +class GrpcDriver(Driver): + """`GrpcDriver` provides an interface to the Driver API. Parameters ---------- - driver_service_address : Optional[str] - The IPv4 or IPv6 address of the Driver API server. - Defaults to `"[::]:9091"`. + run_id : int + The identifier of the run. + driver_service_address : str (default: "[::]:9091") + The address (URL, IPv6, IPv4) of the SuperLink Driver API service. root_certificates : Optional[bytes] (default: None) The PEM-encoded root certificates as a byte string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. 
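With the stub folded into the driver, construction reduces to the run ID and the Driver API address; a minimal sketch (the run ID is illustrative):

from flwr.server.driver.grpc_driver import GrpcDriver

# The channel and DriverStub are created lazily on first use.
driver = GrpcDriver(run_id=123, driver_service_address="[::]:9091")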
""" - def __init__( + def __init__( # pylint: disable=too-many-arguments self, + run_id: int, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, root_certificates: Optional[bytes] = None, ) -> None: - self.driver_service_address = driver_service_address - self.root_certificates = root_certificates - self.channel: Optional[grpc.Channel] = None - self.stub: Optional[DriverStub] = None + self._run_id = run_id + self._addr = driver_service_address + self._cert = root_certificates + self._run: Optional[Run] = None + self._grpc_stub: Optional[DriverStub] = None + self._channel: Optional[grpc.Channel] = None + self.node = Node(node_id=0, anonymous=True) + + @property + def _is_connected(self) -> bool: + """Check if connected to the Driver API server.""" + return self._channel is not None - def is_connected(self) -> bool: - """Return True if connected to the Driver API server, otherwise False.""" - return self.channel is not None + def _connect(self) -> None: + """Connect to the Driver API. - def connect(self) -> None: - """Connect to the Driver API.""" + This will not call GetRun. + """ event(EventType.DRIVER_CONNECT) - if self.channel is not None or self.stub is not None: + if self._is_connected: log(WARNING, "Already connected") return - self.channel = create_channel( - server_address=self.driver_service_address, - insecure=(self.root_certificates is None), - root_certificates=self.root_certificates, + self._channel = create_channel( + server_address=self._addr, + insecure=(self._cert is None), + root_certificates=self._cert, ) - self.stub = DriverStub(self.channel) - log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) + self._grpc_stub = DriverStub(self._channel) + log(DEBUG, "[Driver] Connected to %s", self._addr) - def disconnect(self) -> None: + def _disconnect(self) -> None: """Disconnect from the Driver API.""" event(EventType.DRIVER_DISCONNECT) - if self.channel is None or self.stub is None: + if not self._is_connected: log(DEBUG, "Already disconnected") return - channel = self.channel - self.channel = None - self.stub = None + channel: grpc.Channel = self._channel + self._channel = None + self._grpc_stub = None channel.close() log(DEBUG, "[Driver] Disconnected") - def create_run(self, req: CreateRunRequest) -> CreateRunResponse: - """Request for run ID.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call Driver API - res: CreateRunResponse = self.stub.CreateRun(request=req) - return res - - def get_run(self, req: GetRunRequest) -> GetRunResponse: - """Get run information.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call gRPC Driver API - res: GetRunResponse = self.stub.GetRun(request=req) - return res - - def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: - """Get client IDs.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call gRPC Driver API - res: GetNodesResponse = self.stub.GetNodes(request=req) - return res - - def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: - """Schedule tasks.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance 
not connected") - - # Call gRPC Driver API - res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) - return res - - def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: - """Get task results.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call Driver API - res: PullTaskResResponse = self.stub.PullTaskRes(request=req) - return res - - -class GrpcDriver(Driver): - """`Driver` class provides an interface to the Driver API. - - Parameters - ---------- - run_id : int - The identifier of the run. - stub : Optional[GrpcDriverStub] (default: None) - The ``GrpcDriverStub`` instance used to communicate with the SuperLink. - If None, an instance connected to "[::]:9091" will be created. - """ - - def __init__( # pylint: disable=too-many-arguments - self, - run_id: int, - stub: Optional[GrpcDriverStub] = None, - ) -> None: - self._run_id = run_id - self._run: Optional[Run] = None - self.stub = stub if stub is not None else GrpcDriverStub() - self.node = Node(node_id=0, anonymous=True) + def _init_run(self) -> None: + # Check if is initialized + if self._run is not None: + return + # Get the run info + req = GetRunRequest(run_id=self._run_id) + res: GetRunResponse = self._stub.GetRun(req) + if not res.HasField("run"): + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._run = Run( + run_id=res.run.run_id, + fab_id=res.run.fab_id, + fab_version=res.run.fab_version, + override_config=dict(res.run.override_config.items()), + ) @property def run(self) -> Run: """Run information.""" - self._get_stub_and_run_id() - return Run(**vars(cast(Run, self._run))) + self._init_run() + return Run(**vars(self._run)) - def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: - # Check if is initialized - if self._run is None: - # Connect - if not self.stub.is_connected(): - self.stub.connect() - # Get the run info - req = GetRunRequest(run_id=self._run_id) - res = self.stub.get_run(req) - if not res.HasField("run"): - raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._run = Run( - run_id=res.run.run_id, - fab_id=res.run.fab_id, - fab_version=res.run.fab_version, - ) - - return self.stub, self._run.run_id + @property + def _stub(self) -> DriverStub: + """Driver stub.""" + if not self._is_connected: + self._connect() + return cast(DriverStub, self._grpc_stub) def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( - message.metadata.run_id == cast(Run, self._run).run_id + # Assume self._run being initialized + message.metadata.run_id == self._run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" @@ -234,7 +168,7 @@ def create_message( # pylint: disable=too-many-arguments This method constructs a new `Message` with given content and metadata. The `run_id` and `src_node_id` will be set automatically. 
""" - _, run_id = self._get_stub_and_run_id() + self._init_run() if ttl: warnings.warn( "A custom TTL was set, but note that the SuperLink does not enforce " @@ -245,7 +179,7 @@ def create_message( # pylint: disable=too-many-arguments ttl_ = DEFAULT_TTL if ttl is None else ttl metadata = Metadata( - run_id=run_id, + run_id=self._run_id, message_id="", # Will be set by the server src_node_id=self.node.node_id, dst_node_id=dst_node_id, @@ -258,9 +192,11 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - stub, run_id = self._get_stub_and_run_id() + self._init_run() # Call GrpcDriverStub method - res = stub.get_nodes(GetNodesRequest(run_id=run_id)) + res: GetNodesResponse = self._stub.GetNodes( + GetNodesRequest(run_id=self._run_id) + ) return [node.node_id for node in res.nodes] def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: @@ -269,7 +205,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: This method takes an iterable of messages and sends each message to the node specified in `dst_node_id`. """ - stub, _ = self._get_stub_and_run_id() + self._init_run() # Construct TaskIns task_ins_list: List[TaskIns] = [] for msg in messages: @@ -280,7 +216,9 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: # Add to list task_ins_list.append(taskins) # Call GrpcDriverStub method - res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) + res: PushTaskInsResponse = self._stub.PushTaskIns( + PushTaskInsRequest(task_ins_list=task_ins_list) + ) return list(res.task_ids) def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: @@ -289,9 +227,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: This method is used to collect messages from the SuperLink that correspond to a set of given message IDs. 
""" - stub, _ = self._get_stub_and_run_id() + self._init_run() # Pull TaskRes - res = stub.pull_task_res( + res: PullTaskResResponse = self._stub.PullTaskRes( PullTaskResRequest(node=self.node, task_ids=message_ids) ) # Convert TaskRes to Message @@ -331,7 +269,7 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" # Check if `connect` was called before - if not self.stub.is_connected(): + if not self._is_connected: return # Disconnect - self.stub.disconnect() + self._disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 72efc5f8b2c6..fdf3c676190d 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -41,10 +41,13 @@ def setUp(self) -> None: mock_response = Mock( run=Run(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0") ) - self.mock_grpc_driver_stub = Mock() - self.mock_grpc_driver_stub.get_run.return_value = mock_response - self.mock_grpc_driver_stub.HasField.return_value = True - self.driver = GrpcDriver(run_id=61016, stub=self.mock_grpc_driver_stub) + self.mock_stub = Mock() + self.mock_channel = Mock() + self.mock_stub.GetRun.return_value = mock_response + mock_response.HasField.return_value = True + self.driver = GrpcDriver(run_id=61016) + self.driver._grpc_stub = self.mock_stub # pylint: disable=protected-access + self.driver._channel = self.mock_channel # pylint: disable=protected-access def test_init_grpc_driver(self) -> None: """Test GrpcDriverStub initialization.""" @@ -52,21 +55,21 @@ def test_init_grpc_driver(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") - self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() def test_get_nodes(self) -> None: """Test retrieval of nodes.""" # Prepare mock_response = Mock() mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] - self.mock_grpc_driver_stub.get_nodes.return_value = mock_response + self.mock_stub.GetNodes.return_value = mock_response # Execute node_ids = self.driver.get_node_ids() - args, kwargs = self.mock_grpc_driver_stub.get_nodes.call_args + args, kwargs = self.mock_stub.GetNodes.call_args # Assert - self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) @@ -77,7 +80,7 @@ def test_push_messages_valid(self) -> None: """Test pushing valid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -85,10 +88,10 @@ def test_push_messages_valid(self) -> None: # Execute msg_ids = self.driver.push_messages(msgs) - args, kwargs = self.mock_grpc_driver_stub.push_task_ins.call_args + args, kwargs = self.mock_stub.PushTaskIns.call_args # Assert - self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PushTaskInsRequest) @@ -100,7 +103,7 @@ def test_push_messages_invalid(self) -> None: """Test pushing invalid messages.""" # Prepare mock_response = 
Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -124,16 +127,16 @@ def test_pull_messages_with_given_message_ids(self) -> None: ), TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), ] - self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response + self.mock_stub.PullTaskRes.return_value = mock_response msg_ids = ["id1", "id2", "id3"] # Execute msgs = self.driver.pull_messages(msg_ids) reply_tos = {msg.metadata.reply_to_message for msg in msgs} - args, kwargs = self.mock_grpc_driver_stub.pull_task_res.call_args + args, kwargs = self.mock_stub.PullTaskRes.call_args # Assert - self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullTaskResRequest) @@ -144,14 +147,14 @@ def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" # Prepare mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response # The response message must include either `content` (i.e. a recordset) or # an `Error`. We choose the latter in this case error_proto = error_to_proto(Error(code=0)) mock_response = Mock( task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) - self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response + self.mock_stub.PullTaskRes.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -166,9 +169,9 @@ def test_send_and_receive_messages_timeout(self) -> None: # Prepare sleep_fn = time.sleep mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response mock_response = Mock(task_res_list=[]) - self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response + self.mock_stub.PullTaskRes.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -182,22 +185,20 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" - # Prepare - self.mock_grpc_driver_stub.is_connected.return_value = True - # Execute self.driver.close() # Assert - self.mock_grpc_driver_stub.disconnect.assert_called_once() + self.mock_channel.close.assert_called_once() def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" # Prepare - self.mock_grpc_driver_stub.is_connected.return_value = False + self.driver._grpc_stub = None # pylint: disable=protected-access + self.driver._channel = None # pylint: disable=protected-access # Execute self.driver.close() # Assert - self.mock_grpc_driver_stub.disconnect.assert_not_called() + self.mock_channel.close.assert_not_called() diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 0cc1c5a53e13..d0f32e830f7d 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -86,7 +86,10 @@ def setUp(self) -> None: for _ in 
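Editor's note: the updated tests no longer pass a mocked stub to the constructor; they build a real GrpcDriver and inject mocks into its private attributes so no channel is ever opened. A condensed sketch of that setup, grounded in the test changes above (override_config is passed explicitly here as a precaution, since the Run dataclass now carries it):

    from unittest.mock import Mock

    from flwr.common.typing import Run
    from flwr.server.driver.grpc_driver import GrpcDriver

    mock_response = Mock(
        run=Run(
            run_id=61016, fab_id="mock/mock", fab_version="v1.0.0", override_config={}
        )
    )
    mock_stub = Mock()
    mock_stub.GetRun.return_value = mock_response

    driver = GrpcDriver(run_id=61016)
    # Inject the mocks so the driver never opens a real gRPC channel
    driver._grpc_stub = mock_stub  # pylint: disable=protected-access
    driver._channel = Mock()       # pylint: disable=protected-access
    assert driver.run.run_id == 61016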
range(self.num_nodes) ] self.state.get_run.return_value = Run( - run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" + run_id=61016, + fab_id="mock/mock", + fab_version="v1.0.0", + override_config={"test_key": "test_value"}, ) state_factory = MagicMock(state=lambda: self.state) self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) @@ -98,6 +101,7 @@ def test_get_run(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.override_config["test_key"], "test_value") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: # Prepare state = StateFactory("").state() self.driver = InMemoryDriver( - state.create_run("", ""), MagicMock(state=lambda: state) + state.create_run("", "", {}), MagicMock(state=lambda: state) ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, SqliteState) @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = StateFactory(":flwr-in-memory-state:") state = state_factory.state() - self.driver = InMemoryDriver(state.create_run("", ""), state_factory) + self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, InMemoryState) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 3505ebfdb0a9..b4697e99913f 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -19,16 +19,24 @@ import sys from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Optional +from typing import Dict, Optional from flwr.common import Context, EventType, RecordSet, event -from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir +from flwr.common.config import ( + get_flwr_dir, + get_fused_config, + get_project_config, + get_project_dir, +) from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app -from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, +) from .driver import Driver -from .driver.grpc_driver import GrpcDriver, GrpcDriverStub +from .driver.grpc_driver import GrpcDriver from .server_app import LoadServerAppError, ServerApp ADDRESS_DRIVER_API = "0.0.0.0:9091" @@ -37,6 +45,7 @@ def run( driver: Driver, server_app_dir: str, + server_app_run_config: Dict[str, str], server_app_attr: Optional[str] = None, loaded_server_app: Optional[ServerApp] = None, ) -> None: @@ -69,7 +78,7 @@ def _load() -> ServerApp: server_app = _load() # Initialize Context - context = Context(state=RecordSet()) + context = Context(state=RecordSet(), run_config=server_app_run_config) # Call ServerApp server_app(driver=driver, context=context) @@ -144,22 +153,29 @@ def run_server_app() -> None: # pylint: disable=too-many-branches "For more details, use: ``flower-server-app -h``" ) - stub = GrpcDriverStub( - driver_service_address=args.superlink, root_certificates=root_certificates - ) + # Initialize GrpcDriver if args.run_id is not None: # User provided `--run-id`, but not `server-app` - run_id = args.run_id + driver = GrpcDriver( + 
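Editor's note: with run_config now part of Context, a ServerApp can read per-run configuration from the context it receives. A small illustrative sketch, assuming Context exposes run_config as an attribute; the key name is a placeholder, and in run_serverapp.py the mapping is produced by get_fused_config() rather than written by hand.

    from flwr.common import Context, RecordSet

    # run_config is a plain Dict[str, str]
    context = Context(state=RecordSet(), run_config={"num-server-rounds": "3"})
    num_rounds = int(context.run_config["num-server-rounds"])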
run_id=args.run_id, + driver_service_address=args.superlink, + root_certificates=root_certificates, + ) else: # User provided `server-app`, but not `--run-id` # Create run if run_id is not provided - stub.connect() + driver = GrpcDriver( + run_id=0, # Will be overwritten + driver_service_address=args.superlink, + root_certificates=root_certificates, + ) + # Create run req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version) - res = stub.create_run(req) - run_id = res.run_id + res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212 + # Overwrite driver._run_id + driver._run_id = res.run_id # pylint: disable=W0212 - # Initialize GrpcDriver - driver = GrpcDriver(run_id=run_id, stub=stub) + server_app_run_config = {} # Dynamically obtain ServerApp path based on run_id if args.run_id is not None: @@ -169,6 +185,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir)) config = get_project_config(server_app_dir) server_app_attr = config["flower"]["components"]["serverapp"] + server_app_run_config = get_fused_config(run_, flwr_dir) else: # User provided `server-app`, but not `--run-id` server_app_dir = str(Path(args.dir).absolute()) @@ -182,7 +199,12 @@ def run_server_app() -> None: # pylint: disable=too-many-branches ) # Run the ServerApp with the Driver - run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) + run( + driver=driver, + server_app_dir=server_app_dir, + server_app_run_config=server_app_run_config, + server_app_attr=server_app_attr, + ) # Clean up driver.close() diff --git a/src/py/flwr/server/server_app.py b/src/py/flwr/server/server_app.py index 43b3bcce3f36..e9cb4ddcaf0d 100644 --- a/src/py/flwr/server/server_app.py +++ b/src/py/flwr/server/server_app.py @@ -17,8 +17,11 @@ from typing import Callable, Optional -from flwr.common import Context, RecordSet -from flwr.common.logger import warn_preview_feature +from flwr.common import Context +from flwr.common.logger import ( + warn_deprecated_feature_with_example, + warn_preview_feature, +) from flwr.server.strategy import Strategy from .client_manager import ClientManager @@ -26,7 +29,20 @@ from .driver import Driver from .server import Server from .server_config import ServerConfig -from .typing import ServerAppCallable +from .typing import ServerAppCallable, ServerFn + +SERVER_FN_USAGE_EXAMPLE = """ + + def server_fn(context: Context): + server_config = ServerConfig(num_rounds=3) + strategy = FedAvg() + return ServerAppComponents( + strategy=strategy, + server_config=server_config, + ) + + app = ServerApp(server_fn=server_fn) +""" class ServerApp: @@ -36,13 +52,15 @@ class ServerApp: -------- Use the `ServerApp` with an existing `Strategy`: - >>> server_config = ServerConfig(num_rounds=3) - >>> strategy = FedAvg() + >>> def server_fn(context: Context): + >>> server_config = ServerConfig(num_rounds=3) + >>> strategy = FedAvg() + >>> return ServerAppComponents( + >>> strategy=strategy, + >>> server_config=server_config, + >>> ) >>> - >>> app = ServerApp( - >>> server_config=server_config, - >>> strategy=strategy, - >>> ) + >>> app = ServerApp(server_fn=server_fn) Use the `ServerApp` with a custom main function: @@ -53,23 +71,52 @@ class ServerApp: >>> print("ServerApp running") """ + # pylint: disable=too-many-arguments def __init__( self, server: Optional[Server] = None, config: Optional[ServerConfig] = None, strategy: Optional[Strategy] = None, client_manager: 
Optional[ClientManager] = None, + server_fn: Optional[ServerFn] = None, ) -> None: + if any([server, config, strategy, client_manager]): + warn_deprecated_feature_with_example( + deprecation_message="Passing either `server`, `config`, `strategy` or " + "`client_manager` directly to the ServerApp " + "constructor is deprecated.", + example_message="Pass `ServerApp` arguments wrapped " + "in a `flwr.server.ServerAppComponents` object that gets " + "returned by a function passed as the `server_fn` argument " + "to the `ServerApp` constructor. For example: ", + code_example=SERVER_FN_USAGE_EXAMPLE, + ) + + if server_fn: + raise ValueError( + "Passing `server_fn` is incompatible with passing the " + "other arguments (now deprecated) to ServerApp. " + "Use `server_fn` exclusively." + ) + self._server = server self._config = config self._strategy = strategy self._client_manager = client_manager + self._server_fn = server_fn self._main: Optional[ServerAppCallable] = None def __call__(self, driver: Driver, context: Context) -> None: """Execute `ServerApp`.""" # Compatibility mode if not self._main: + if self._server_fn: + # Execute server_fn() + components = self._server_fn(context) + self._server = components.server + self._config = components.config + self._strategy = components.strategy + self._client_manager = components.client_manager start_driver( server=self._server, config=self._config, @@ -80,7 +127,6 @@ def __call__(self, driver: Driver, context: Context) -> None: return # New execution mode - context = Context(state=RecordSet()) self._main(driver, context) def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]: diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index 0751a0cb2bc5..7de8774d4c81 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -29,7 +29,7 @@ def test_server_app_custom_mode() -> None: # Prepare app = ServerApp() driver = MagicMock() - context = Context(state=RecordSet()) + context = Context(state=RecordSet(), run_config={}) called = {"called": False} diff --git a/src/py/flwr/server/serverapp_components.py b/src/py/flwr/server/serverapp_components.py new file mode 100644 index 000000000000..315f0a889a61 --- /dev/null +++ b/src/py/flwr/server/serverapp_components.py @@ -0,0 +1,52 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ServerAppComponents for the ServerApp.""" + + +from dataclasses import dataclass +from typing import Optional + +from .client_manager import ClientManager +from .server import Server +from .server_config import ServerConfig +from .strategy import Strategy + + +@dataclass +class ServerAppComponents: # pylint: disable=too-many-instance-attributes + """Components to construct a ServerApp. 
+ + Parameters + ---------- + server : Optional[Server] (default: None) + A server implementation, either `flwr.server.Server` or a subclass + thereof. If no instance is provided, one will be created internally. + config : Optional[ServerConfig] (default: None) + Currently supported values are `num_rounds` (int, default: 1) and + `round_timeout` in seconds (float, default: None). + strategy : Optional[Strategy] (default: None) + An implementation of the abstract base class + `flwr.server.strategy.Strategy`. If no strategy is provided, then + `flwr.server.strategy.FedAvg` will be used. + client_manager : Optional[ClientManager] (default: None) + An implementation of the class `flwr.server.ClientManager`. If no + implementation is provided, then `flwr.server.SimpleClientManager` + will be used. + """ + + server: Optional[Server] = None + config: Optional[ServerConfig] = None + strategy: Optional[Strategy] = None + client_manager: Optional[ClientManager] = None diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 03128f02158e..7f8ded3bdb85 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -69,7 +69,11 @@ def CreateRun( """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") state: State = self.state_factory.state() - run_id = state.create_run(request.fab_id, request.fab_version) + run_id = state.create_run( + request.fab_id, + request.fab_version, + dict(request.override_config.items()), + ) return CreateRunResponse(run_id=run_id) def PushTaskIns( diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 01499102b7d8..798e71435585 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._client_private_key, self._server_public_key @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) client_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(client_private_key, self._server_public_key) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py index 1d5e3a6a51ad..31c64bd3b233 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -33,8 +33,8 @@ def __init__(self, backend_config: BackendConfig, work_dir: str) -> None: """Construct a backend.""" @abstractmethod - async def build(self) -> None: - """Build backend asynchronously. + def build(self) -> None: + """Build backend. Different components need to be in place before workers in a backend are ready to accept jobs. 
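Editor's note: the new ServerAppComponents dataclass and the server_fn callback replace passing server, config, strategy, and client_manager directly to ServerApp. Below is a self-contained version of the usage example introduced in this diff, with the imports it needs; the exact public import paths are assumed here rather than confirmed by the diff, and the dataclass field is named config (not server_config).

    from flwr.common import Context
    from flwr.server import ServerApp, ServerConfig
    from flwr.server.serverapp_components import ServerAppComponents
    from flwr.server.strategy import FedAvg


    def server_fn(context: Context) -> ServerAppComponents:
        # Build everything the ServerApp needs for this run
        config = ServerConfig(num_rounds=3)
        strategy = FedAvg()
        return ServerAppComponents(config=config, strategy=strategy)


    app = ServerApp(server_fn=server_fn)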
When this method finishes executing, the backend should be fully @@ -54,11 +54,11 @@ def is_worker_idle(self) -> bool: """Report whether a backend worker is idle and can therefore run a ClientApp.""" @abstractmethod - async def terminate(self) -> None: + def terminate(self) -> None: """Terminate backend.""" @abstractmethod - async def process_message( + def process_message( self, app: Callable[[], ClientApp], message: Message, diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index b6b9e248a656..0d2f4d193f0b 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -153,12 +153,12 @@ def is_worker_idle(self) -> bool: """Report whether the pool has idle actors.""" return self.pool.is_actor_available() - async def build(self) -> None: + def build(self) -> None: """Build pool of Ray actors that this backend will submit jobs to.""" - await self.pool.add_actors_to_pool(self.pool.actors_capacity) + self.pool.add_actors_to_pool(self.pool.actors_capacity) log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors) - async def process_message( + def process_message( self, app: Callable[[], ClientApp], message: Message, @@ -172,17 +172,16 @@ async def process_message( try: # Submit a task to the pool - future = await self.pool.submit( + future = self.pool.submit( lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), (app, message, str(partition_id), context), ) - await future # Fetch result ( out_mssg, updated_context, - ) = await self.pool.fetch_result_and_return_actor_to_pool(future) + ) = self.pool.fetch_result_and_return_actor_to_pool(future) return out_mssg, updated_context @@ -193,11 +192,11 @@ async def process_message( self.__class__.__name__, ) # add actor back into pool - await self.pool.add_actor_back_to_pool(future) + self.pool.add_actor_back_to_pool(future) raise ex - async def terminate(self) -> None: + def terminate(self) -> None: """Terminate all actors in actor pool.""" - await self.pool.terminate_all_actors() + self.pool.terminate_all_actors() ray.shutdown() log(DEBUG, "Terminated %s", self.__class__.__name__) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index fa7374f08853..287983003f8c 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -14,11 +14,10 @@ # ============================================================================== """Test for Ray backend for the Fleet API using the Simulation Engine.""" -import asyncio from math import pi from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Union -from unittest import IsolatedAsyncioTestCase +from unittest import TestCase import ray @@ -84,18 +83,18 @@ def _load_app() -> ClientApp: return _load_app -async def backend_build_process_and_termination( +def backend_build_process_and_termination( backend: RayBackend, process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None, ) -> Union[Tuple[Message, Context], None]: """Build, process job and terminate RayBackend.""" - await backend.build() + backend.build() to_return = None if process_args: - to_return = await backend.process_message(*process_args) + to_return = backend.process_message(*process_args) - await backend.terminate() + 
backend.terminate() return to_return @@ -121,7 +120,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: ) # Construct emtpy Context - context = Context(state=RecordSet()) + context = Context(state=RecordSet(), run_config={}) # Expected output expected_output = pi * mult_factor @@ -129,10 +128,10 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: return message, context, expected_output -class AsyncTestRayBackend(IsolatedAsyncioTestCase): - """A basic class that allows runnig multliple asyncio tests.""" +class TestRayBackend(TestCase): + """A basic class that allows runnig multliple tests.""" - async def on_cleanup(self) -> None: + def doCleanups(self) -> None: """Ensure Ray has shutdown.""" if ray.is_initialized(): ray.shutdown() @@ -140,9 +139,7 @@ async def on_cleanup(self) -> None: def test_backend_creation_and_termination(self) -> None: """Test creation of RayBackend and its termination.""" backend = RayBackend(backend_config={}, work_dir="") - asyncio.run( - backend_build_process_and_termination(backend=backend, process_args=None) - ) + backend_build_process_and_termination(backend=backend, process_args=None) def test_backend_creation_submit_and_termination( self, @@ -157,10 +154,8 @@ def test_backend_creation_submit_and_termination( message, context, expected_output = _create_message_and_context() - res = asyncio.run( - backend_build_process_and_termination( - backend=backend, process_args=(client_app_callable, message, context) - ) + res = backend_build_process_and_termination( + backend=backend, process_args=(client_app_callable, message, context) ) if res is None: @@ -189,7 +184,6 @@ def test_backend_creation_submit_and_termination_non_existing_client_app( self.test_backend_creation_submit_and_termination( client_app_loader=_load_from_module("a_non_existing_module:app") ) - self.addAsyncCleanup(self.on_cleanup) def test_backend_creation_submit_and_termination_existing_client_app( self, @@ -217,7 +211,6 @@ def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdi client_app_loader=_load_from_module("raybackend_test:client_app"), workdir="/?&%$^#%@$!", ) - self.addAsyncCleanup(self.on_cleanup) def test_backend_creation_with_init_arguments(self) -> None: """Testing whether init args are properly parsed to Ray.""" @@ -248,5 +241,3 @@ def test_backend_creation_with_init_arguments(self) -> None: nodes = ray.nodes() assert nodes[0]["Resources"]["CPU"] == backend_config_2["init_args"]["num_cpus"] - - self.addAsyncCleanup(self.on_cleanup) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 78dc6a900aae..3c0b36e1ca3c 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -14,14 +14,18 @@ # ============================================================================== """Fleet Simulation Engine API.""" -import asyncio + import json import sys +import threading import time import traceback +from concurrent.futures import ThreadPoolExecutor from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path -from typing import Callable, Dict, List, Optional +from queue import Empty, Queue +from time import sleep +from typing import Callable, Dict, Optional from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState @@ -30,8 +34,8 @@ from flwr.common.message import Error from flwr.common.object_ref import load_app from 
flwr.common.serde import message_from_taskins, message_to_taskres -from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -from flwr.server.superlink.state import StateFactory +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.server.superlink.state import State, StateFactory from .backend import Backend, error_messages_backends, supported_backends @@ -52,19 +56,21 @@ def _register_nodes( # pylint: disable=too-many-arguments,too-many-locals -async def worker( +def worker( app_fn: Callable[[], ClientApp], - queue: "asyncio.Queue[TaskIns]", + taskins_queue: "Queue[TaskIns]", + taskres_queue: "Queue[TaskRes]", node_states: Dict[int, NodeState], - state_factory: StateFactory, backend: Backend, + f_stop: threading.Event, ) -> None: """Get TaskIns from queue and pass it to an actor in the pool to execute it.""" - state = state_factory.state() - while True: + while not f_stop.is_set(): out_mssg = None try: - task_ins: TaskIns = await queue.get() + # Fetch from queue with timeout. We use a timeout so + # the stopping event can be evaluated even when the queue is empty. + task_ins: TaskIns = taskins_queue.get(timeout=1.0) node_id = task_ins.task.consumer.node_id # Register and retrieve runstate @@ -75,7 +81,7 @@ async def worker( message = message_from_taskins(task_ins) # Let backend process message - out_mssg, updated_context = await backend.process_message( + out_mssg, updated_context = backend.process_message( app_fn, message, context ) @@ -83,11 +89,9 @@ async def worker( node_states[node_id].update_context( task_ins.run_id, context=updated_context ) - - except asyncio.CancelledError as e: - log(DEBUG, "Terminating async worker: %s", e) - break - + except Empty: + # An exception raised if queue.get times out + pass # Exceptions aren't raised but reported as an error message except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, ex) @@ -111,67 +115,48 @@ async def worker( task_res = message_to_taskres(out_mssg) # Store TaskRes in state task_res.task.pushed_at = time.time() - state.store_task_res(task_res) + taskres_queue.put(task_res) -async def add_taskins_to_queue( - queue: "asyncio.Queue[TaskIns]", - state_factory: StateFactory, +def add_taskins_to_queue( + state: State, + queue: "Queue[TaskIns]", nodes_mapping: NodeToPartitionMapping, - backend: Backend, - consumers: List["asyncio.Task[None]"], - f_stop: asyncio.Event, + f_stop: threading.Event, ) -> None: - """Retrieve TaskIns and add it to the queue.""" - state = state_factory.state() - num_initial_consumers = len(consumers) + """Put TaskIns in a queue from State.""" while not f_stop.is_set(): for node_id in nodes_mapping.keys(): - task_ins = state.get_task_ins(node_id=node_id, limit=1) - if task_ins: - await queue.put(task_ins[0]) - - # Count consumers that are running - num_active = sum(not (cc.done()) for cc in consumers) - - # Alert if number of consumers decreased by half - if num_active < num_initial_consumers // 2: - log( - WARN, - "Number of active workers has more than halved: (%i/%i active)", - num_active, - num_initial_consumers, - ) + task_ins_list = state.get_task_ins(node_id=node_id, limit=1) + for task_ins in task_ins_list: + queue.put(task_ins) + sleep(0.1) - # Break if consumers died - if num_active == 0: - raise RuntimeError("All workers have died. 
Ending Simulation.") - # Log some stats - log( - DEBUG, - "Simulation Engine stats: " - "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)", - num_active, - num_initial_consumers, - backend.__class__.__name__, - backend.num_workers, - queue.qsize(), - ) - await asyncio.sleep(1.0) - log(DEBUG, "Async producer: Stopped pulling from StateFactory.") +def put_taskres_into_state( + state: State, queue: "Queue[TaskRes]", f_stop: threading.Event +) -> None: + """Put TaskRes into State from a queue.""" + while not f_stop.is_set(): + try: + taskres = queue.get(timeout=1.0) + state.store_task_res(taskres) + except Empty: + # queue is empty when timeout was triggered + pass -async def run( +def run( app_fn: Callable[[], ClientApp], backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, state_factory: StateFactory, node_states: Dict[int, NodeState], - f_stop: asyncio.Event, + f_stop: threading.Event, ) -> None: - """Run the VCE async.""" - queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128) + """Run the VCE.""" + taskins_queue: "Queue[TaskIns]" = Queue() + taskres_queue: "Queue[TaskRes]" = Queue() try: @@ -179,27 +164,48 @@ async def run( backend = backend_fn() # Build backend - await backend.build() + backend.build() # Add workers (they submit Messages to Backend) - worker_tasks = [ - asyncio.create_task( - worker(app_fn, queue, node_states, state_factory, backend) - ) - for _ in range(backend.num_workers) - ] - # Create producer (adds TaskIns into Queue) - producer = asyncio.create_task( - add_taskins_to_queue( - queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop - ) + state = state_factory.state() + + extractor_th = threading.Thread( + target=add_taskins_to_queue, + args=( + state, + taskins_queue, + nodes_mapping, + f_stop, + ), ) + extractor_th.start() - # Wait for producer to finish - # The producer runs forever until f_stop is set or until - # all worker (consumer) coroutines are completed. Workers - # also run forever and only end if an exception is raised. 
- await asyncio.gather(producer) + injector_th = threading.Thread( + target=put_taskres_into_state, + args=( + state, + taskres_queue, + f_stop, + ), + ) + injector_th.start() + + with ThreadPoolExecutor() as executor: + _ = [ + executor.submit( + worker, + app_fn, + taskins_queue, + taskres_queue, + node_states, + backend, + f_stop, + ) + for _ in range(backend.num_workers) + ] + + extractor_th.join() + injector_th.join() except Exception as ex: @@ -214,18 +220,9 @@ async def run( raise RuntimeError("Simulation Engine crashed.") from ex finally: - # Produced task terminated, now cancel worker tasks - for w_t in worker_tasks: - _ = w_t.cancel() - - while not all(w_t.done() for w_t in worker_tasks): - log(DEBUG, "Terminating async workers...") - await asyncio.sleep(0.5) - - await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()]) # Terminate backend - await backend.terminate() + backend.terminate() # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches @@ -234,7 +231,7 @@ def start_vce( backend_name: str, backend_config_json_stream: str, app_dir: str, - f_stop: asyncio.Event, + f_stop: threading.Event, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, num_supernodes: Optional[int] = None, @@ -338,15 +335,13 @@ def _load() -> ClientApp: _ = app_fn() # Run main simulation loop - asyncio.run( - run( - app_fn, - backend_fn, - nodes_mapping, - state_factory, - node_states, - f_stop, - ) + run( + app_fn, + backend_fn, + nodes_mapping, + state_factory, + node_states, + f_stop, ) except LoadClientAppError as loadapp_ex: f_stop_delay = 10 diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index df9f2cc96f95..7d37f03f6ade 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -15,7 +15,6 @@ """Test Fleet Simulation Engine API.""" -import asyncio import threading import time from itertools import cycle @@ -24,7 +23,7 @@ from pathlib import Path from time import sleep from typing import Dict, Optional, Set, Tuple -from unittest import IsolatedAsyncioTestCase +from unittest import TestCase from uuid import UUID from flwr.client.client_app import LoadClientAppError @@ -46,7 +45,7 @@ from flwr.server.superlink.state import InMemoryState, StateFactory -def terminate_simulation(f_stop: asyncio.Event, sleep_duration: int) -> None: +def terminate_simulation(f_stop: threading.Event, sleep_duration: int) -> None: """Set event to terminate Simulation Engine after `sleep_duration` seconds.""" sleep(sleep_duration) f_stop.set() @@ -82,7 +81,9 @@ def register_messages_into_state( ) -> Dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore - state.run_ids[run_id] = Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0") + state.run_ids[run_id] = Run( + run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={} + ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes @@ -146,15 +147,15 @@ def start_and_shutdown( ) -> None: """Start Simulation Engine and terminate after specified number of seconds. - Some tests need to be terminated by triggering externally an asyncio.Event. This - is enabled whtn passing `duration`>0. 
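Editor's note: the asyncio event loop in the Simulation Engine is replaced by plain threads: one thread moves TaskIns from State into a queue, worker threads in a ThreadPoolExecutor process them, and the results flow back through a second queue. A stripped-down, runnable sketch of that producer/worker shape, with plain integers standing in for TaskIns/TaskRes:

    import threading
    from concurrent.futures import ThreadPoolExecutor
    from queue import Empty, Queue
    from time import sleep

    f_stop = threading.Event()
    task_queue: "Queue[int]" = Queue()
    result_queue: "Queue[int]" = Queue()


    def producer() -> None:
        for i in range(10):
            task_queue.put(i)
            sleep(0.01)


    def worker() -> None:
        while not f_stop.is_set():
            try:
                # Timeout so the stop flag is re-checked even when the queue is empty
                item = task_queue.get(timeout=0.1)
            except Empty:
                continue
            result_queue.put(item * 2)


    threading.Thread(target=producer).start()
    with ThreadPoolExecutor() as executor:
        for _ in range(2):
            executor.submit(worker)
        sleep(0.5)
        f_stop.set()  # workers exit their loops on the next queue timeout

    print(sorted(result_queue.get() for _ in range(10)))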
+ Some tests need to be terminated by triggering externally an threading.Event. This + is enabled when passing `duration`>0. """ - f_stop = asyncio.Event() + f_stop = threading.Event() if duration: # Setup thread that will set the f_stop event, triggering the termination of all - # asyncio logic in the Simulation Engine. It will also terminate the Backend. + # logic in the Simulation Engine. It will also terminate the Backend. termination_th = threading.Thread( target=terminate_simulation, args=(f_stop, duration) ) @@ -179,8 +180,8 @@ def start_and_shutdown( termination_th.join() -class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase): - """A basic class that enables testing asyncio functionalities.""" +class TestFleetSimulationEngineRayBackend(TestCase): + """A basic class that enables testing functionalities.""" def test_erroneous_no_supernodes_client_mapping(self) -> None: """Test with unset arguments.""" diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 5a4e4eb0fd9a..bc4bd4478a23 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -275,7 +275,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" return self.public_key_to_node_id.get(client_public_key) - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id with self.lock: @@ -283,7 +288,10 @@ def create_run(self, fab_id: str, fab_version: str) -> int: if run_id not in self.run_ids: self.run_ids[run_id] = Run( - run_id=run_id, fab_id=fab_id, fab_version=fab_version + run_id=run_id, + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, ) return run_id log(ERROR, "Unexpected run creation failure.") diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 725f7c2dff4b..ea6f349b9f9a 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -15,6 +15,7 @@ """SQLite based implemenation of server state.""" +import json import re import sqlite3 import time @@ -61,9 +62,10 @@ SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( - run_id INTEGER UNIQUE, - fab_id TEXT, - fab_version TEXT + run_id INTEGER UNIQUE, + fab_id TEXT, + fab_version TEXT, + override_config TEXT ); """ @@ -613,7 +615,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: return node_id return None - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) @@ -622,8 +629,13 @@ def create_run(self, fab_id: str, fab_version: str) -> int: query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" # If run_id does not exist if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);" - self.query(query, (run_id, fab_id, fab_version)) + query = ( + "INSERT INTO run (run_id, fab_id, fab_version, override_config)" 
+ "VALUES (?, ?, ?, ?);" + ) + self.query( + query, (run_id, fab_id, fab_version, json.dumps(override_config)) + ) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -687,7 +699,10 @@ def get_run(self, run_id: int) -> Optional[Run]: try: row = self.query(query, (run_id,))[0] return Run( - run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"] + run_id=run_id, + fab_id=row["fab_id"], + fab_version=row["fab_version"], + override_config=json.loads(row["override_config"]), ) except sqlite3.IntegrityError: log(ERROR, "`run_id` does not exist.") diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 65e2c63cab69..c93f6ba756b8 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,7 +16,7 @@ import abc -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set from uuid import UUID from flwr.common.typing import Run @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" @abc.abstractmethod - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 373202d5cde6..5f0d23ffc4d8 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -52,7 +52,7 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("Mock/mock", "v1.0.0") + run_id = state.create_run("Mock/mock", "v1.0.0", {"test_key": "test_value"}) # Execute run = state.get_run(run_id) @@ -62,6 +62,7 @@ def test_create_and_get_run(self) -> None: assert run.run_id == run_id assert run.fab_id == "Mock/mock" assert run.fab_version == "v1.0.0" + assert run.override_config["test_key"] == "test_value" def test_get_task_ins_empty(self) -> None: """Validate that a new state has no TaskIns.""" @@ -90,7 +91,7 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -125,7 +126,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_0 = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -199,7 +200,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -214,7 +215,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", 
"v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -228,7 +229,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -242,7 +243,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -259,7 +260,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -302,7 +303,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, @@ -323,7 +324,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -335,7 +336,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [] # Execute @@ -352,7 +353,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -368,7 +369,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -390,7 +391,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10) # Execute @@ -405,7 +406,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -422,7 +423,7 
@@ def test_delete_node_public_key_none(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = 0 # Execute & Assert @@ -441,7 +442,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -460,7 +461,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -475,7 +476,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_run("mock/mock", "v1.0.0") + state.create_run("mock/mock", "v1.0.0", {}) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -489,7 +490,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -507,7 +508,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_res( producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) @@ -608,7 +609,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -627,7 +628,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns diff --git a/src/py/flwr/server/typing.py b/src/py/flwr/server/typing.py index 01143af74392..cdb1c0db4fe7 100644 --- a/src/py/flwr/server/typing.py +++ b/src/py/flwr/server/typing.py @@ -20,6 +20,8 @@ from flwr.common import Context from .driver import Driver +from .serverapp_components import ServerAppComponents ServerAppCallable = Callable[[Driver, Context], None] Workflow = Callable[[Driver, Context], None] +ServerFn = Callable[[Context], ServerAppComponents] diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 50a7dcad70a9..446b0bdeba38 100644 --- 
a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -29,12 +29,14 @@ from flwr.client import ClientFnExt from flwr.common import EventType, event -from flwr.common.logger import log, set_logger_propagation +from flwr.common.constant import NODE_ID_NUM_BYTES +from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature from flwr.server.client_manager import ClientManager from flwr.server.history import History from flwr.server.server import Server, init_defaults, run_fl from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy +from flwr.server.superlink.state.utils import generate_rand_int_from_bytes from flwr.simulation.ray_transport.ray_actor import ( ClientAppActor, VirtualClientEngineActor, @@ -51,7 +53,7 @@ `start_simulation( *, client_fn: ClientFn, - num_clients: Optional[int] = None, + num_clients: int, clients_ids: Optional[List[str]] = None, client_resources: Optional[Dict[str, float]] = None, server: Optional[Server] = None, @@ -70,13 +72,29 @@ """ +NodeToPartitionMapping = Dict[int, int] + + +def _create_node_id_to_partition_mapping( + num_clients: int, +) -> NodeToPartitionMapping: + """Generate a node_id:partition_id mapping.""" + nodes_mapping: NodeToPartitionMapping = {} # {node-id; partition-id} + for i in range(num_clients): + while True: + node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) + if node_id not in nodes_mapping: + break + nodes_mapping[node_id] = i + return nodes_mapping + # pylint: disable=too-many-arguments,too-many-statements,too-many-branches def start_simulation( *, client_fn: ClientFnExt, - num_clients: Optional[int] = None, - clients_ids: Optional[List[str]] = None, + num_clients: int, + clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED client_resources: Optional[Dict[str, float]] = None, server: Optional[Server] = None, config: Optional[ServerConfig] = None, @@ -102,13 +120,14 @@ def start_simulation( (model, dataset, hyperparameters, ...) should be (re-)created in either the call to `client_fn` or the call to any of the client methods (e.g., load evaluation data in the `evaluate` method itself). - num_clients : Optional[int] - The total number of clients in this simulation. This must be set if - `clients_ids` is not set and vice-versa. + num_clients : int + The total number of clients in this simulation. clients_ids : Optional[List[str]] + UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD. List `client_id`s for each client. This is only required if `num_clients` is not set. Setting both `num_clients` and `clients_ids` with `len(clients_ids)` not equal to `num_clients` generates an error. + Using this argument will raise an error. client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`) CPU and GPU resources for a single client. Supported keys are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by @@ -158,7 +177,6 @@ def start_simulation( is an advanced feature. For all details, please refer to the Ray documentation: https://docs.ray.io/en/latest/ray-core/scheduling/index.html - Returns ------- hist : flwr.server.history.History @@ -170,6 +188,14 @@ def start_simulation( {"num_clients": len(clients_ids) if clients_ids is not None else num_clients}, ) + if clients_ids is not None: + warn_unsupported_feature( + "Passing `clients_ids` to `start_simulation` is deprecated and not longer " + "used by `start_simulation`. Use `num_clients` exclusively instead." 
+ ) + log(ERROR, "`clients_ids` argument used.") + sys.exit() + # Set logger propagation loop: Optional[asyncio.AbstractEventLoop] = None try: @@ -196,20 +222,8 @@ def start_simulation( initialized_config, ) - # clients_ids takes precedence - cids: List[str] - if clients_ids is not None: - if (num_clients is not None) and (len(clients_ids) != num_clients): - log(ERROR, INVALID_ARGUMENTS_START_SIMULATION) - sys.exit() - else: - cids = clients_ids - else: - if num_clients is None: - log(ERROR, INVALID_ARGUMENTS_START_SIMULATION) - sys.exit() - else: - cids = [str(x) for x in range(num_clients)] + # Create node-id to partition-id mapping + nodes_mapping = _create_node_id_to_partition_mapping(num_clients) # Default arguments for Ray initialization if not ray_init_args: @@ -308,10 +322,11 @@ def update_resources(f_stop: threading.Event) -> None: ) # Register one RayClientProxy object for each client with the ClientManager - for cid in cids: + for node_id, partition_id in nodes_mapping.items(): client_proxy = RayActorClientProxy( client_fn=client_fn, - cid=cid, + node_id=node_id, + partition_id=partition_id, actor_pool=pool, ) initialized_server.client_manager().register(client=client_proxy) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 7afffb865334..b1c9d2b9c0c1 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -14,7 +14,6 @@ # ============================================================================== """Ray-based Flower Actor and ActorPool implementation.""" -import asyncio import threading from abc import ABC from logging import DEBUG, ERROR, WARNING @@ -411,9 +410,7 @@ def __init__( self.client_resources = client_resources # Queue of idle actors - self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue( - maxsize=1024 - ) + self.pool: List[VirtualClientEngineActor] = [] self.num_actors = 0 # Resolve arguments to pass during actor init @@ -427,38 +424,37 @@ def __init__( # Figure out how many actors can be created given the cluster resources # and the resources the user indicates each VirtualClient will need self.actors_capacity = pool_size_from_resources(client_resources) - self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {} + self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {} def is_actor_available(self) -> bool: """Return true if there is an idle actor.""" - return self.pool.qsize() > 0 + return len(self.pool) > 0 - async def add_actors_to_pool(self, num_actors: int) -> None: + def add_actors_to_pool(self, num_actors: int) -> None: """Add actors to the pool. This method may be executed also if new resources are added to your Ray cluster (e.g. you add a new node). 
""" for _ in range(num_actors): - await self.pool.put(self.create_actor_fn()) # type: ignore + self.pool.append(self.create_actor_fn()) # type: ignore self.num_actors += num_actors - async def terminate_all_actors(self) -> None: + def terminate_all_actors(self) -> None: """Terminate actors in pool.""" num_terminated = 0 - while self.pool.qsize(): - actor = await self.pool.get() + for actor in self.pool: actor.terminate.remote() # type: ignore num_terminated += 1 log(DEBUG, "Terminated %i actors", num_terminated) - async def submit( + def submit( self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] ) -> Any: """On idle actor, submit job and return future.""" # Remove idle actor from pool - actor = await self.pool.get() + actor = self.pool.pop() # Submit job to actor app_fn, mssg, cid, context = job future = actor_fn(actor, app_fn, mssg, cid, context) @@ -467,18 +463,18 @@ async def submit( self._future_to_actor[future] = actor return future - async def add_actor_back_to_pool(self, future: Any) -> None: + def add_actor_back_to_pool(self, future: Any) -> None: """Ad actor assigned to run future back into the pool.""" actor = self._future_to_actor.pop(future) - await self.pool.put(actor) + self.pool.append(actor) - async def fetch_result_and_return_actor_to_pool( + def fetch_result_and_return_actor_to_pool( self, future: Any ) -> Tuple[Message, Context]: """Pull result given a future and add actor back to pool.""" - # Get actor that ran job - await self.add_actor_back_to_pool(future) # Retrieve result for object store # Instead of doing ray.get(future) we await it - _, out_mssg, updated_context = await future + _, out_mssg, updated_context = ray.get(future) + # Get actor that ran job + self.add_actor_back_to_pool(future) return out_mssg, updated_context diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 3a9712df978e..31bc22c84bd5 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -44,16 +44,22 @@ class RayActorClientProxy(ClientProxy): """Flower client proxy which delegates work using Ray.""" def __init__( - self, client_fn: ClientFnExt, cid: str, actor_pool: VirtualClientEngineActorPool + self, + client_fn: ClientFnExt, + node_id: int, + partition_id: int, + actor_pool: VirtualClientEngineActorPool, ): - super().__init__(cid) + super().__init__(cid=str(node_id)) + self.node_id = node_id + self.partition_id = partition_id def _load_app() -> ClientApp: return ClientApp(client_fn=client_fn) self.app_fn = _load_app self.actor_pool = actor_pool - self.proxy_state = NodeState(partition_id=int(self.cid)) + self.proxy_state = NodeState(partition_id=self.partition_id) def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: """Sumbit a message to the ActorPool.""" @@ -67,11 +73,13 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: try: self.actor_pool.submit_client_job( - lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), - (self.app_fn, message, self.cid, state), + lambda a, a_fn, mssg, partition_id, state: a.run.remote( + a_fn, mssg, partition_id, state + ), + (self.app_fn, message, str(self.partition_id), state), ) out_mssg, updated_context = self.actor_pool.get_client_result( - self.cid, timeout + str(self.partition_id), timeout ) # Update state @@ -103,7 +111,7 @@ def _wrap_recordset_in_message( message_id="", group_id=str(group_id) if group_id 
is not None else "", src_node_id=0, - dst_node_id=int(self.cid), + dst_node_id=self.node_id, reply_to_message="", ttl=timeout if timeout else DEFAULT_TTL, message_type=message_type, diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index df059135b4e0..83f6cfe05313 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -39,6 +39,7 @@ recordset_to_getpropertiesres, ) from flwr.common.recordset_compat_test import _get_valid_getpropertiesins +from flwr.simulation.app import _create_node_id_to_partition_mapping from flwr.simulation.ray_transport.ray_actor import ( ClientAppActor, VirtualClientEngineActor, @@ -68,9 +69,7 @@ def get_dummy_client( node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument ) -> Client: """Return a DummyClient converted to Client type.""" - if partition_id is None: - raise ValueError("`partition_id` is not set.") - return DummyClient(partition_id).to_client() + return DummyClient(node_id).to_client() def prep( @@ -91,13 +90,15 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]: # Create 373 client proxies num_proxies = 373 # a prime number + mapping = _create_node_id_to_partition_mapping(num_proxies) proxies = [ RayActorClientProxy( client_fn=get_dummy_client, - cid=str(cid), + node_id=node_id, + partition_id=partition_id, actor_pool=pool, ) - for cid in range(num_proxies) + for node_id, partition_id in mapping.items() ] return proxies, pool @@ -127,7 +128,7 @@ def test_cid_consistency_one_at_a_time() -> None: res = recordset_to_getpropertiesres(message_out.content) - assert int(prox.cid) * pi == res.properties["result"] + assert int(prox.node_id) * pi == res.properties["result"] ray.shutdown() @@ -160,21 +161,21 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: ) prox.actor_pool.submit_client_job( lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), - (prox.app_fn, message, prox.cid, state), + (prox.app_fn, message, str(prox.node_id), state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: message_out, updated_context = prox.actor_pool.get_client_result( - prox.cid, timeout=None + str(prox.node_id), timeout=None ) prox.proxy_state.update_context(run_id, context=updated_context) res = recordset_to_getpropertiesres(message_out.content) - assert int(prox.cid) * pi == res.properties["result"] + assert prox.node_id * pi == res.properties["result"] assert ( - str(int(prox.cid) * pi) + str(prox.node_id * pi) == prox.proxy_state.retrieve_context(run_id).state.configs_records[ "result" ]["result"] @@ -187,7 +188,7 @@ def test_cid_consistency_without_proxies() -> None: """Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy.""" proxies, pool = prep() num_clients = len(proxies) - cids = [str(cid) for cid in range(num_clients)] + node_ids = list(range(num_clients)) getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) @@ -196,8 +197,8 @@ def _load_app() -> ClientApp: return ClientApp(client_fn=get_dummy_client) # submit all jobs (collect later) - shuffle(cids) - for cid in cids: + shuffle(node_ids) + for node_id in node_ids: message = Message( content=recordset, metadata=Metadata( @@ -205,27 +206,27 @@ def _load_app() -> ClientApp: message_id="", group_id=str(0), src_node_id=0, - dst_node_id=12345, + dst_node_id=node_id, 
reply_to_message="", ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), ) pool.submit_client_job( - lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), + lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state), ( _load_app, message, - cid, - Context(state=RecordSet(), partition_id=int(cid)), + str(node_id), + Context(state=RecordSet(), run_config={}, partition_id=node_id), ), ) # fetch results one at a time - shuffle(cids) - for cid in cids: - message_out, _ = pool.get_client_result(cid, timeout=None) + shuffle(node_ids) + for node_id in node_ids: + message_out, _ = pool.get_client_result(str(node_id), timeout=None) res = recordset_to_getpropertiesres(message_out.content) - assert int(cid) * pi == res.properties["result"] + assert node_id * pi == res.properties["result"] ray.shutdown() diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 7c7a412a245b..de101fe3e09f 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -22,7 +22,7 @@ import traceback from logging import DEBUG, ERROR, INFO, WARNING from time import sleep -from typing import Optional +from typing import Dict, Optional from flwr.client import ClientApp from flwr.common import EventType, event, log @@ -126,16 +126,25 @@ def run_simulation( def run_serverapp_th( server_app_attr: Optional[str], server_app: Optional[ServerApp], + server_app_run_config: Dict[str, str], driver: Driver, app_dir: str, - f_stop: asyncio.Event, + f_stop: threading.Event, + has_exception: threading.Event, enable_tf_gpu_growth: bool, delay_launch: int = 3, ) -> threading.Thread: """Run SeverApp in a thread.""" - def server_th_with_start_checks( # type: ignore - tf_gpu_growth: bool, stop_event: asyncio.Event, **kwargs + def server_th_with_start_checks( + tf_gpu_growth: bool, + stop_event: threading.Event, + exception_event: threading.Event, + _driver: Driver, + _server_app_dir: str, + _server_app_run_config: Dict[str, str], + _server_app_attr: Optional[str], + _server_app: Optional[ServerApp], ) -> None: """Run SeverApp, after check if GPU memory growth has to be set. @@ -147,10 +156,18 @@ def server_th_with_start_checks( # type: ignore enable_gpu_growth() # Run ServerApp - run(**kwargs) + run( + driver=_driver, + server_app_dir=_server_app_dir, + server_app_run_config=_server_app_run_config, + server_app_attr=_server_app_attr, + loaded_server_app=_server_app, + ) except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, "ServerApp thread raised an exception: %s", ex) log(ERROR, traceback.format_exc()) + exception_event.set() + raise finally: log(DEBUG, "ServerApp finished running.") # Upon completion, trigger stop event if one was passed @@ -160,13 +177,16 @@ def server_th_with_start_checks( # type: ignore serverapp_th = threading.Thread( target=server_th_with_start_checks, - args=(enable_tf_gpu_growth, f_stop), - kwargs={ - "server_app_attr": server_app_attr, - "loaded_server_app": server_app, - "driver": driver, - "server_app_dir": app_dir, - }, + args=( + enable_tf_gpu_growth, + f_stop, + has_exception, + driver, + app_dir, + server_app_run_config, + server_app_attr, + server_app, + ), ) sleep(delay_launch) serverapp_th.start() @@ -196,20 +216,18 @@ def _main_loop( server_app: Optional[ServerApp] = None, server_app_attr: Optional[str] = None, ) -> None: - """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread. 
- - Everything runs on the main thread or a separate one, depending on whether the main - thread already contains a running Asyncio event loop. This is the case if running - the Simulation Engine on a Jupyter/Colab notebook. - """ + """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread.""" # Initialize StateFactory state_factory = StateFactory(":flwr-in-memory-state:") - f_stop = asyncio.Event() + f_stop = threading.Event() + # A Threading event to indicate if an exception was raised in the ServerApp thread + server_app_thread_has_exception = threading.Event() serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id_ = state_factory.state().create_run("", "") + run_id_ = state_factory.state().create_run("", "", {}) + server_app_run_config: Dict[str, str] = {} if run_id: _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id) @@ -222,9 +240,11 @@ def _main_loop( serverapp_th = run_serverapp_th( server_app_attr=server_app_attr, server_app=server_app, + server_app_run_config=server_app_run_config, driver=driver, app_dir=app_dir, f_stop=f_stop, + has_exception=server_app_thread_has_exception, enable_tf_gpu_growth=enable_tf_gpu_growth, ) @@ -253,6 +273,8 @@ event(EventType.RUN_SUPERLINK_LEAVE) if serverapp_th: serverapp_th.join() + if server_app_thread_has_exception.is_set(): + raise RuntimeError("Exception in ServerApp thread") log(DEBUG, "Stopping Simulation Engine now.") @@ -349,7 +371,6 @@ def _run_simulation( # Convert config to original JSON-stream format backend_config_stream = json.dumps(backend_config) - simulation_engine_th = None args = ( num_supernodes, backend_name, @@ -363,31 +384,26 @@ server_app_attr, ) # Detect if there is an Asyncio event loop already running. - # If yes, run everything on a separate thread. In environments - # like Jupyter/Colab notebooks, there is an event loop present. - run_in_thread = False + # If yes, disable logger propagation. In environments + # like Jupyter/Colab notebooks, it's often better to do this. + asyncio_loop_running = False try: _ = ( asyncio.get_running_loop() ) # Raises RuntimeError if no event loop is present log(DEBUG, "Asyncio event loop already running.") - run_in_thread = True + asyncio_loop_running = True except RuntimeError: - log(DEBUG, "No asyncio event loop running") + pass finally: - if run_in_thread: + if asyncio_loop_running: # Set logger propagation to False to prevent duplicated log output in Colab.
logger = set_logger_propagation(logger, False) - log(DEBUG, "Starting Simulation Engine on a new thread.") - simulation_engine_th = threading.Thread(target=_main_loop, args=args) - simulation_engine_th.start() - simulation_engine_th.join() - else: - log(DEBUG, "Starting Simulation Engine on the main thread.") - _main_loop(*args) + + _main_loop(*args) def _parse_args_run_simulation() -> argparse.ArgumentParser: diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index fa89e83b5e75..b4d4b462bbcc 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -77,6 +77,7 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser: parser.add_argument( "executor", help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.", + default="flwr.superexec.deployment:executor", ) parser.add_argument( "--address", diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 6f931e81eefa..d117f280b38d 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -17,7 +17,7 @@ import subprocess import sys from logging import ERROR, INFO -from typing import Optional +from typing import Dict, Optional from typing_extensions import override @@ -53,18 +53,29 @@ def _connect(self) -> None: ) self.stub = DriverStub(channel) - def _create_run(self, fab_id: str, fab_version: str) -> int: + def _create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: if self.stub is None: self._connect() assert self.stub is not None - req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version) + req = CreateRunRequest( + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, + ) res = self.stub.CreateRun(request=req) return int(res.run_id) @override - def start_run(self, fab_file: bytes) -> Optional[RunTracker]: + def start_run( + self, fab_file: bytes, override_config: Dict[str, str] + ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" try: # Install FAB to flwr dir @@ -79,7 +90,7 @@ def start_run(self, fab_file: bytes) -> Optional[RunTracker]: ) # Call SuperLink to create run - run_id: int = self._create_run(fab_id, fab_version) + run_id: int = self._create_run(fab_id, fab_version, override_config) log(INFO, "Created run %s", str(run_id)) # Start ServerApp diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index e5ef2bd59a79..61a7bc289af3 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -45,7 +45,10 @@ def StartRun( """Create run ID.""" log(INFO, "ExecServicer.StartRun") - run = self.executor.start_run(request.fab_file) + run = self.executor.start_run( + request.fab_file, + dict(request.override_config.items()), + ) if run is None: log(ERROR, "Executor failed to start run") diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index 41f67b74c48b..edc91df4530e 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -36,7 +36,7 @@ def test_start_run() -> None: run_res.proc = proc executor = MagicMock() - executor.start_run = lambda _: run_res + executor.start_run = lambda _, __: run_res context_mock = MagicMock() diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index f85ac4c157fc..85b6e5c3e095 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ 
-17,7 +17,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from subprocess import Popen -from typing import Optional +from typing import Dict, Optional @dataclass @@ -33,8 +33,7 @@ class Executor(ABC): @abstractmethod def start_run( - self, - fab_file: bytes, + self, fab_file: bytes, override_config: Dict[str, str] ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version.
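# --- Illustrative sketch, not part of this patch ----------------------------
# A hypothetical executor written against the updated
# start_run(fab_file, override_config) signature shown above. The class name,
# the echo subprocess, and the run_id value are made-up placeholders; it also
# assumes RunTracker exposes `run_id` and `proc`, and that `start_run` is the
# only abstract method, as the surrounding excerpt suggests.
import subprocess
import sys
from typing import Dict, Optional

from typing_extensions import override

from flwr.superexec.executor import Executor, RunTracker


class EchoExecutor(Executor):
    """Toy executor that only echoes the run's override config."""

    @override
    def start_run(
        self, fab_file: bytes, override_config: Dict[str, str]
    ) -> Optional[RunTracker]:
        # A real executor (e.g. the deployment executor above) would install
        # the FAB and launch a ServerApp; here we spawn a placeholder process
        # so that a RunTracker can still be returned to the servicer.
        proc = subprocess.Popen(  # pylint: disable=consider-using-with
            [sys.executable, "-c", f"print({override_config!r})"]
        )
        return RunTracker(run_id=0, proc=proc)
# -----------------------------------------------------------------------------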