diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index b8ddd355eb8e..7d78ea881034 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -81,7 +81,7 @@ jobs: - name: Set up QEMU if: matrix.platform.qemu != '' - uses: docker/setup-qemu-action@68827325e0b33c7199eb31dd4e31fbe9023e06e3 # v3.0.0 + uses: docker/setup-qemu-action@5927c834f5b4fdf503fca6f4c7eccda82949e1ee # v3.1.0 with: platforms: ${{ matrix.platform.qemu }} @@ -92,7 +92,7 @@ jobs: images: ${{ inputs.namespace-repository }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0 + uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3.4.0 - name: Login to Docker Hub uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 @@ -122,7 +122,7 @@ jobs: touch "/tmp/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3 + uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 with: name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }} path: /tmp/digests/* @@ -138,7 +138,7 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # v4.1.7 + uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests @@ -152,7 +152,7 @@ jobs: tags: ${{ inputs.tags }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0 + uses: docker/setup-buildx-action@4fd812986e6c8c2a69e18311145f9371337f27d4 # v3.4.0 - name: Login to Docker Hub uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 diff --git a/.github/workflows/docker-images.yml 
b/.github/workflows/docker-images.yml deleted file mode 100644 index e341ae62e3f7..000000000000 --- a/.github/workflows/docker-images.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Build docker images - -on: - workflow_dispatch: - inputs: - flwr-version: - description: "Version of Flower." - required: true - type: string - -permissions: - contents: read - -jobs: - parameters: - name: Collect build parameters - runs-on: ubuntu-22.04 - timeout-minutes: 10 - outputs: - pip-version: ${{ steps.versions.outputs.pip-version }} - setuptools-version: ${{ steps.versions.outputs.setuptools-version }} - matrix: ${{ steps.matrix.outputs.matrix }} - steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - - - uses: ./.github/actions/bootstrap - id: bootstrap - - - id: versions - run: | - echo "pip-version=${{ steps.bootstrap.outputs.pip-version }}" >> "$GITHUB_OUTPUT" - echo "setuptools-version=${{ steps.bootstrap.outputs.setuptools-version }}" >> "$GITHUB_OUTPUT" - - - id: matrix - run: | - python dev/build-docker-image-matrix.py --flwr-version ${{ github.event.inputs.flwr-version }} > matrix.json - echo "matrix=$(cat matrix.json)" >> $GITHUB_OUTPUT - - build-base-images: - name: Build base images - uses: ./.github/workflows/_docker-build.yml - needs: parameters - strategy: - fail-fast: false - matrix: ${{ fromJson(needs.parameters.outputs.matrix).base }} - with: - namespace-repository: ${{ matrix.images.namespace_repository }} - file-dir: ${{ matrix.images.file_dir }} - build-args: | - PYTHON_VERSION=${{ matrix.images.python_version }} - PIP_VERSION=${{ needs.parameters.outputs.pip-version }} - SETUPTOOLS_VERSION=${{ needs.parameters.outputs.setuptools-version }} - DISTRO=${{ matrix.images.distro.name }} - DISTRO_VERSION=${{ matrix.images.distro.version }} - FLWR_VERSION=${{ matrix.images.flwr_version }} - tags: ${{ matrix.images.tag }} - secrets: - dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} - dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} - - 
build-binary-images: - name: Build binary images - uses: ./.github/workflows/_docker-build.yml - needs: [parameters, build-base-images] - strategy: - fail-fast: false - matrix: ${{ fromJson(needs.parameters.outputs.matrix).binary }} - with: - namespace-repository: ${{ matrix.images.namespace_repository }} - file-dir: ${{ matrix.images.file_dir }} - build-args: BASE_IMAGE=${{ matrix.images.base_image }} - tags: ${{ matrix.images.tags }} - secrets: - dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} - dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 4cbee6f770d6..11c0073fcf1f 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -66,44 +66,39 @@ jobs: - directory: bare-client-auth - - directory: jax + - directory: framework-jax - - directory: pytorch + - directory: framework-pytorch dataset: | from torchvision.datasets import CIFAR10 CIFAR10('./data', download=True) - - directory: tensorflow + - directory: framework-tensorflow dataset: | import tensorflow as tf tf.keras.datasets.cifar10.load_data() - - directory: tabnet - dataset: | - import tensorflow_datasets as tfds - tfds.load(name='iris', split=tfds.Split.TRAIN) - - - directory: opacus + - directory: framework-opacus dataset: | from torchvision.datasets import CIFAR10 CIFAR10('./data', download=True) - - directory: pytorch-lightning + - directory: framework-pytorch-lightning dataset: | from torchvision.datasets import MNIST MNIST('./data', download=True) - - directory: scikit-learn + - directory: framework-scikit-learn dataset: | import openml openml.datasets.get_dataset(554) - - directory: fastai + - directory: framework-fastai dataset: | from fastai.vision.all import untar_data, URLs untar_data(URLs.MNIST) - - directory: pandas + - directory: framework-pandas dataset: | from pathlib import Path from sklearn.datasets import load_iris @@ -145,7 +140,7 @@ jobs: run: python -c "${{ matrix.dataset }}" - name: Run edge client 
test if: ${{ matrix.directory != 'bare-client-auth' }} - run: ./../test.sh "${{ matrix.directory }}" + run: ./../test_legacy.sh "${{ matrix.directory }}" - name: Run virtual client test if: ${{ matrix.directory != 'bare-client-auth' }} run: python simulation.py @@ -154,16 +149,16 @@ jobs: run: python simulation_next.py - name: Run driver test if: ${{ matrix.directory != 'bare-client-auth' }} - run: ./../test_driver.sh "${{ matrix.directory }}" + run: ./../test_superlink.sh "${{ matrix.directory }}" - name: Run driver test with REST if: ${{ matrix.directory == 'bare' }} - run: ./../test_driver.sh bare rest + run: ./../test_superlink.sh bare rest - name: Run driver test with SQLite database if: ${{ matrix.directory == 'bare' }} - run: ./../test_driver.sh bare sqlite + run: ./../test_superlink.sh bare sqlite - name: Run driver test with client authentication if: ${{ matrix.directory == 'bare-client-auth' }} - run: ./../test_driver.sh bare client-auth + run: ./../test_superlink.sh bare client-auth - name: Run reconnection test with SQLite database if: ${{ matrix.directory == 'bare' }} run: ./../test_reconnection.sh sqlite diff --git a/.github/workflows/framework-release.yml b/.github/workflows/framework-release.yml index a941b47d58fc..812d5b1e398e 100644 --- a/.github/workflows/framework-release.yml +++ b/.github/workflows/framework-release.yml @@ -43,3 +43,70 @@ jobs: curl $tar_url --output dist/$tar_name python -m poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN_RELEASE_FLWR }} + + parameters: + if: ${{ github.repository == 'adap/flower' }} + name: Collect docker build parameters + runs-on: ubuntu-22.04 + timeout-minutes: 10 + needs: publish + outputs: + pip-version: ${{ steps.versions.outputs.pip-version }} + setuptools-version: ${{ steps.versions.outputs.setuptools-version }} + matrix: ${{ steps.matrix.outputs.matrix }} + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + + - uses: ./.github/actions/bootstrap + id: bootstrap 
+ + - id: versions + run: | + echo "pip-version=${{ steps.bootstrap.outputs.pip-version }}" >> "$GITHUB_OUTPUT" + echo "setuptools-version=${{ steps.bootstrap.outputs.setuptools-version }}" >> "$GITHUB_OUTPUT" + + - id: matrix + run: | + FLWR_VERSION=$(poetry version -s) + python dev/build-docker-image-matrix.py --flwr-version "${FLWR_VERSION}" > matrix.json + echo "matrix=$(cat matrix.json)" >> "$GITHUB_OUTPUT" + + build-base-images: + if: ${{ github.repository == 'adap/flower' }} + name: Build base images + uses: ./.github/workflows/_docker-build.yml + needs: parameters + strategy: + fail-fast: false + matrix: ${{ fromJson(needs.parameters.outputs.matrix).base }} + with: + namespace-repository: ${{ matrix.images.namespace_repository }} + file-dir: ${{ matrix.images.file_dir }} + build-args: | + PYTHON_VERSION=${{ matrix.images.python_version }} + PIP_VERSION=${{ needs.parameters.outputs.pip-version }} + SETUPTOOLS_VERSION=${{ needs.parameters.outputs.setuptools-version }} + DISTRO=${{ matrix.images.distro.name }} + DISTRO_VERSION=${{ matrix.images.distro.version }} + FLWR_VERSION=${{ matrix.images.flwr_version }} + tags: ${{ matrix.images.tag }} + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} + + build-binary-images: + if: ${{ github.repository == 'adap/flower' }} + name: Build binary images + uses: ./.github/workflows/_docker-build.yml + needs: [parameters, build-base-images] + strategy: + fail-fast: false + matrix: ${{ fromJson(needs.parameters.outputs.matrix).binary }} + with: + namespace-repository: ${{ matrix.images.namespace_repository }} + file-dir: ${{ matrix.images.file_dir }} + build-args: BASE_IMAGE=${{ matrix.images.base_image }} + tags: ${{ matrix.images.tags }} + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/release-nightly.yml b/.github/workflows/release-nightly.yml index 
2b72190bede5..97751aafc031 100644 --- a/.github/workflows/release-nightly.yml +++ b/.github/workflows/release-nightly.yml @@ -69,7 +69,8 @@ jobs: images: [ { repository: "flwr/superlink", file_dir: "src/docker/superlink" }, { repository: "flwr/supernode", file_dir: "src/docker/supernode" }, - { repository: "flwr/serverapp", file_dir: "src/docker/serverapp" } + { repository: "flwr/serverapp", file_dir: "src/docker/serverapp" }, + { repository: "flwr/superexec", file_dir: "src/docker/superexec" } ] with: namespace-repository: ${{ matrix.images.repository }} diff --git a/datasets/README.md b/datasets/README.md index 1d8014d57ea3..50fc67376ae4 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -42,11 +42,12 @@ Create **custom partitioning schemes** or choose from the **implemented [partiti * IID partitioning `IidPartitioner(num_partitions)` * Dirichlet partitioning `DirichletPartitioner(num_partitions, partition_by, alpha)` * InnerDirichlet partitioning `InnerDirichletPartitioner(partition_sizes, partition_by, alpha)` -* Natural ID partitioner `NaturalIdPartitioner(partition_by)` -* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner` -* Linear partitioner `LinearPartitioner(num_partitions)` -* Square partitioner `SquarePartitioner(num_partitions)` -* Exponential partitioner `ExponentialPartitioner(num_partitions)` +* Pathological partitioning `PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)` +* Natural ID partitioning `NaturalIdPartitioner(partition_by)` +* Size-based partitioning (the abstract base class for the partitioners dictating the division based on the number of samples) `SizePartitioner` +* Linear partitioning `LinearPartitioner(num_partitions)` +* Square partitioning `SquarePartitioner(num_partitions)` +* Exponential partitioning `ExponentialPartitioner(num_partitions)` * more to come in the future releases 
(contributions are welcome).

Comparison of partitioning schemes. diff --git a/datasets/doc/source/conf.py b/datasets/doc/source/conf.py index 11285c375f96..c287ec318b5d 100644 --- a/datasets/doc/source/conf.py +++ b/datasets/doc/source/conf.py @@ -162,7 +162,7 @@ def find_test_modules(package_path): .. raw:: html
- + Open in Colab """ diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst index bdcea7650bbc..fcc7920711bf 100644 --- a/datasets/doc/source/index.rst +++ b/datasets/doc/source/index.rst @@ -94,6 +94,7 @@ Here are a few of the ``Partitioner`` s that are available: (for a full list see * IID partitioning ``IidPartitioner(num_partitions)`` * Dirichlet partitioning ``DirichletPartitioner(num_partitions, partition_by, alpha)`` * InnerDirichlet partitioning ``InnerDirichletPartitioner(partition_sizes, partition_by, alpha)`` +* Pathological partitioning ``PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)`` * Natural ID partitioner ``NaturalIdPartitioner(partition_by)`` * Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) ``SizePartitioner`` * Linear partitioner ``LinearPartitioner(num_partitions)`` diff --git a/datasets/flwr_datasets/mock_utils_test.py b/datasets/flwr_datasets/mock_utils_test.py index 78aff1f1cdd7..bd49de8033de 100644 --- a/datasets/flwr_datasets/mock_utils_test.py +++ b/datasets/flwr_datasets/mock_utils_test.py @@ -190,7 +190,7 @@ def _generate_random_image_column( pil_imgs = [] for np_image in np_images: # Convert the NumPy array to a PIL image - pil_img_beg = Image.fromarray(np_image) # type: ignore + pil_img_beg = Image.fromarray(np_image) # Save the image to an in-memory bytes buffer in_memory_file = io.BytesIO() diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 1fc00ed90323..0c75dbce387a 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -22,6 +22,7 @@ from .linear_partitioner import LinearPartitioner from .natural_id_partitioner import NaturalIdPartitioner from .partitioner import Partitioner +from .pathological_partitioner import PathologicalPartitioner from .shard_partitioner import 
ShardPartitioner from .size_partitioner import SizePartitioner from .square_partitioner import SquarePartitioner @@ -34,6 +35,7 @@ "LinearPartitioner", "NaturalIdPartitioner", "Partitioner", + "PathologicalPartitioner", "ShardPartitioner", "SizePartitioner", "SquarePartitioner", diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py new file mode 100644 index 000000000000..1ee60d283044 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py @@ -0,0 +1,305 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pathological partitioner class that works with Hugging Face Datasets.""" + + +import warnings +from typing import Any, Dict, List, Literal, Optional + +import numpy as np + +import datasets +from flwr_datasets.common.typing import NDArray +from flwr_datasets.partitioner.partitioner import Partitioner + + +# pylint: disable=too-many-arguments, too-many-instance-attributes +class PathologicalPartitioner(Partitioner): + """Partition dataset such that each partition has a chosen number of classes. + + Implementation based on Federated Learning on Non-IID Data Silos: An Experimental + Study https://arxiv.org/pdf/2102.02079. + + The algorithm first determines which classes will be assigned to which partitions. 
+ For each partition `num_classes_per_partition` are sampled in a way chosen in + `class_assignment_mode`. Given the information about the required classes for each + partition, it is determined into how many parts the samples corresponding to this + label should be divided. Such division is performed for each class. + + Parameters + ---------- + num_partitions : int + The total number of partitions that the data will be divided into. + partition_by : str + Column name of the labels (targets) based on which partitioning works. + num_classes_per_partition: int + The (exact) number of unique classes that each partition will have. + class_assignment_mode: Literal["random", "deterministic", "first-deterministic"] + How the classes are assigned to the partitions. The default is "random". + The possible values are: + + - "random": Randomly assign classes to the partitions. For each partition choose + the `num_classes_per_partition` classes without replacement. + - "first-deterministic": Assign the first class for each partition in a + deterministic way (class id is the partition_id % num_unique_classes). + The rest of the classes are assigned randomly. In case the number of + partitions is smaller than the number of unique classes, not all classes will + be used in the first iteration, otherwise all the classes will be used (such + that each is present in at least one partition). + - "deterministic": Assign all the classes to the partitions in a deterministic + way. Classes are assigned based on the formula: partition_id has classes + identified by the index: (partition_id + i) % num_unique_classes + where i in {0, ..., num_classes_per_partition - 1}. So, partition 0 will have + classes 0, 1, 2, ..., `num_classes_per_partition`-1, partition 1 will have + classes 1, 2, 3, ...,`num_classes_per_partition`, .... + + The list representing the unique labels is sorted in ascending order. In case + of numbers starting from zero the class id corresponds to the number itself. 
+ `class_assignment_mode="first-deterministic"` was used in the original paper, + here we provide the option to use the other modes as well. + shuffle: bool + Whether to randomize the order of samples. Shuffling is applied after the + samples assignment to partitions. + seed: Optional[int] + Seed used for the random class assignment and for dataset shuffling. + + Examples + -------- + In order to mimic the original behavior of the paper follow the setup below + (the `class_assignment_mode="first-deterministic"`): + + >>> from flwr_datasets.partitioner import PathologicalPartitioner + >>> from flwr_datasets import FederatedDataset + >>> + >>> partitioner = PathologicalPartitioner( + >>> num_partitions=10, + >>> partition_by="label", + >>> num_classes_per_partition=2, + >>> class_assignment_mode="first-deterministic" + >>> ) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + """ + + def __init__( + self, + num_partitions: int, + partition_by: str, + num_classes_per_partition: int, + class_assignment_mode: Literal[ + "random", "deterministic", "first-deterministic" + ] = "random", + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + self._num_partitions = num_partitions + self._partition_by = partition_by + self._num_classes_per_partition = num_classes_per_partition + self._class_assignment_mode = class_assignment_mode + self._shuffle = shuffle + self._seed = seed + self._rng = np.random.default_rng(seed=self._seed) + + # Utility attributes + self._partition_id_to_indices: Dict[int, List[int]] = {} + self._partition_id_to_unique_labels: Dict[int, List[Any]] = { + pid: [] for pid in range(self._num_partitions) + } + self._unique_labels: List[Any] = [] + # Count in how many partitions the label is used + self._unique_label_to_times_used_counter: Dict[Any, int] = {} + self._partition_id_to_indices_determined = False + + def load_partition(self, partition_id: int) -> 
datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + partition_id : int + The index that corresponds to the requested partition. + + Returns + ------- + dataset_partition : Dataset + Single partition of a dataset. + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. + self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self.dataset.select(self._partition_id_to_indices[partition_id]) + + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self._num_partitions + + def _determine_partition_id_to_indices_if_needed(self) -> None: + """Create an assignment of indices to the partition indices.""" + if self._partition_id_to_indices_determined: + return + self._determine_partition_id_to_unique_labels() + assert self._unique_labels is not None + self._count_partitions_having_each_unique_label() + + labels = np.asarray(self.dataset[self._partition_by]) + self._check_correctness_of_unique_label_to_times_used_counter(labels) + for partition_id in range(self._num_partitions): + self._partition_id_to_indices[partition_id] = [] + + unused_labels = [] + for unique_label in self._unique_labels: + if self._unique_label_to_times_used_counter[unique_label] == 0: + unused_labels.append(unique_label) + continue + # Get the indices in the original dataset where the y == unique_label + unique_label_to_indices = np.where(labels == unique_label)[0] + + split_unique_labels_to_indices = np.array_split( + unique_label_to_indices, + self._unique_label_to_times_used_counter[unique_label], + ) + + split_index = 0 + for partition_id in range(self._num_partitions): + if unique_label in 
self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_indices[partition_id].extend( + split_unique_labels_to_indices[split_index] + ) + split_index += 1 + + if len(unused_labels) >= 1: + warnings.warn( + f"Classes: {unused_labels} will NOT be used due to the chosen " + f"configuration. If it is undesired behavior consider setting" + f" class_assignment_mode='first-deterministic' which in case when" + f" the number of classes is smaller than the number of partitions will " + f"utilize all the classes for the created partitions.", + stacklevel=1, + ) + if self._shuffle: + for indices in self._partition_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + + self._partition_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._partition_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _determine_partition_id_to_unique_labels(self) -> None: + """Determine the assignment of unique labels to the partitions.""" + self._unique_labels = sorted(self.dataset.unique(self._partition_by)) + num_unique_classes = len(self._unique_labels) + + if self._num_classes_per_partition > num_unique_classes: + raise ValueError( + f"The specified `num_classes_per_partition`" + f"={self._num_classes_per_partition} is greater than the number " + f"of unique classes in the given dataset={num_unique_classes}. " + f"Reduce the `num_classes_per_partition` or make use different dataset " + f"to apply this partitioning." 
+ ) + if self._class_assignment_mode == "first-deterministic": + # Index the sorted unique labels: they need not be 0..n-1 ints + for partition_id in range(self._num_partitions): + label = self._unique_labels[partition_id % num_unique_classes] + self._partition_id_to_unique_labels[partition_id].append(label) + + while ( + len(self._partition_id_to_unique_labels[partition_id]) + < self._num_classes_per_partition + ): + label = self._rng.choice(self._unique_labels, size=1)[0] + if label not in self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_unique_labels[partition_id].append(label) + elif self._class_assignment_mode == "deterministic": + for partition_id in range(self._num_partitions): + labels = [] + for i in range(self._num_classes_per_partition): + label = self._unique_labels[ + (partition_id + i) % len(self._unique_labels) + ] + labels.append(label) + self._partition_id_to_unique_labels[partition_id] = labels + elif self._class_assignment_mode == "random": + for partition_id in range(self._num_partitions): + labels = self._rng.choice( + self._unique_labels, + size=self._num_classes_per_partition, + replace=False, + ).tolist() + self._partition_id_to_unique_labels[partition_id] = labels + else: + raise ValueError( + f"The supported class_assignment_mode are: 'random', 'deterministic', " + f"'first-deterministic'. You provided: {self._class_assignment_mode}." + ) + + def _count_partitions_having_each_unique_label(self) -> None: + """Count the number of partitions that have each unique label. + + This computation is based on the assignment of the label to the partition_id in + the `_determine_partition_id_to_unique_labels` method. 
+ Given: + * partition 0 has only labels: 0,1 (not necessarily just two samples it can have + many samples but either from 0 or 1) + * partition 1 has only labels: 1, 2 (same count note as above) + * and there are only two partitions then the following will be computed: + { + 0: 1, + 1: 2, + 2: 1 + } + """ + for unique_label in self._unique_labels: + self._unique_label_to_times_used_counter[unique_label] = 0 + for unique_labels in self._partition_id_to_unique_labels.values(): + for unique_label in unique_labels: + self._unique_label_to_times_used_counter[unique_label] += 1 + + def _check_correctness_of_unique_label_to_times_used_counter( + self, labels: NDArray + ) -> None: + """Check if partitioning is possible given the presence requirements. + + The number of times the label can be used must be smaller or equal to the number + of times that the label is present in the dataset. + """ + for unique_label in self._unique_labels: + num_unique = np.sum(labels == unique_label) + if self._unique_label_to_times_used_counter[unique_label] > num_unique: + raise ValueError( + f"Label: {unique_label} is needed to be assigned to more " + f"partitions " + f"({self._unique_label_to_times_used_counter[unique_label]})" + f" than there are samples (corresponding to this label) in the " + f"dataset ({num_unique}). Please decrease the `num_partitions`, " + f"`num_classes_per_partition` to avoid this situation, " + f"or try `class_assigment_mode='deterministic'` to create a more " + f"even distribution of classes along the partitions. " + f"Alternatively use a different dataset if you can not adjust" + f" the any of these parameters." + ) diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py new file mode 100644 index 000000000000..151b7e14659c --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py @@ -0,0 +1,262 @@ +# Copyright 2024 Flower Labs GmbH. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for PathologicalPartitioner.""" + + +import unittest +from typing import Dict + +import numpy as np +from parameterized import parameterized + +import datasets +from datasets import Dataset +from flwr_datasets.partitioner.pathological_partitioner import PathologicalPartitioner + + +def _dummy_dataset_setup( + num_samples: int, partition_by: str, num_unique_classes: int +) -> Dataset: + """Create a dummy dataset for testing.""" + data = { + partition_by: np.tile( + np.arange(num_unique_classes), num_samples // num_unique_classes + 1 + )[:num_samples], + "features": np.random.randn(num_samples), + } + return Dataset.from_dict(data) + + +def _dummy_heterogeneous_dataset_setup( + num_samples: int, partition_by: str, num_unique_classes: int +) -> Dataset: + """Create a dummy dataset for testing.""" + data = { + partition_by: np.tile( + np.arange(num_unique_classes), num_samples // num_unique_classes + 1 + )[:num_samples], + "features": np.random.randn(num_samples), + } + return Dataset.from_dict(data) + + +class TestClassConstrainedPartitioner(unittest.TestCase): + """Unit tests for PathologicalPartitioner.""" + + @parameterized.expand( # type: ignore + [ + # num_partition, num_classes_per_partition, num_samples, total_classes + (3, 1, 60, 3), # Single class per partition scenario + (5, 2, 100, 5), + (5, 2, 
100, 10), + (4, 3, 120, 6), + ] + ) + def test_correct_num_classes_when_partitioned( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test correct number of unique classes.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitions: Dict[int, Dataset] = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + unique_classes_per_partition = { + pid: np.unique(partition["labels"]) for pid, partition in partitions.items() + } + + for unique_classes in unique_classes_per_partition.values(): + self.assertEqual(num_classes_per_partition, len(unique_classes)) + + def test_first_class_deterministic_assignment(self) -> None: + """Test deterministic assignment of first classes to partitions. + + Test if all the classes are used (which has to be the case, given num_partitions + >= than the number of unique classes). 
+ """ + dataset = _dummy_dataset_setup(100, "labels", 10) + partitioner = PathologicalPartitioner( + num_partitions=10, + partition_by="labels", + num_classes_per_partition=2, + class_assignment_mode="first-deterministic", + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + expected_classes = set(range(10)) + actual_classes = set() + for pid in range(10): + partition = partitioner.load_partition(pid) + actual_classes.update(np.unique(partition["labels"])) + self.assertEqual(expected_classes, actual_classes) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (4, 2, 80, 8), + (10, 2, 100, 10), + ] + ) + def test_deterministic_class_assignment( + self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes + ): + """Test deterministic assignment of classes to partitions.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="deterministic", + ) + partitioner.dataset = dataset + partitions = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + + # Verify each partition has the expected classes, order does not matter + for pid, partition in partitions.items(): + expected_labels = sorted( + [ + (pid + i) % num_unique_classes + for i in range(num_classes_per_partition) + ] + ) + actual_labels = sorted(np.unique(partition["labels"])) + self.assertTrue( + np.array_equal(expected_labels, actual_labels), + f"Partition {pid} does not have the expected labels: " + f"{expected_labels} but instead {actual_labels}.", + ) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 3, 20, 3), + ] + ) + def test_too_many_partitions_for_a_class( + self, num_partitions, 
num_classes_per_partition, num_samples, num_unique_classes + ) -> None: + """Test too many partitions for the number of samples in a class.""" + dataset_1 = _dummy_dataset_setup( + num_samples // 2, "labels", num_unique_classes - 1 + ) + # Create a skewed part of the dataset for the last label + data = { + "labels": np.array([num_unique_classes - 1] * (num_samples // 2)), + "features": np.random.randn(num_samples // 2), + } + dataset_2 = Dataset.from_dict(data) + dataset = datasets.concatenate_datasets([dataset_1, dataset_2]) + + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="random", + ) + partitioner.dataset = dataset + + with self.assertRaises(ValueError) as context: + _ = partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "Label: 0 is needed to be assigned to more partitions (10) than there are " + "samples (corresponding to this label) in the dataset (5). " + "Please decrease the `num_partitions`, `num_classes_per_partition` to " + "avoid this situation, or try `class_assigment_mode='deterministic'` to " + "create a more even distribution of classes along the partitions. 
" + "Alternatively use a different dataset if you can not adjust the any of " + "these parameters.", + ) + + @parameterized.expand( # type: ignore + [ + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 11, 100, 10), # 11 > 10 + (5, 11, 100, 10), # 11 > 10 + (10, 20, 100, 5), # 20 > 5 + ] + ) + def test_more_classes_per_partition_than_num_unique_classes_in_dataset_raises( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test more num_classes_per_partition > num_unique_classes in the dataset.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The specified " + f"`num_classes_per_partition`={num_classes_per_partition} is " + f"greater than the number of unique classes in the given " + f"dataset={len(dataset.unique('labels'))}. 
Reduce the " + f"`num_classes_per_partition` or make use different dataset " + f"to apply this partitioning.", + ) + + @parameterized.expand( # type: ignore + [ + # num_classes_per_partition should be irrelevant since the exception should + # be raised at the very beginning + # num_partitions, num_classes_per_partition, num_samples + (10, 2, 5), + (10, 10, 5), + (100, 10, 99), + ] + ) + def test_more_partitions_than_samples_raises( + self, num_partitions: int, num_classes_per_partition: int, num_samples: int + ) -> None: + """Test if generation of more partitions that there are samples raises.""" + # The number of unique classes in the dataset should be irrelevant since the + # exception should be raised at the very beginning + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes=5) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The number of partitions needs to be smaller than the number of " + "samples in the dataset.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/dev/build-docker-image-matrix.py b/dev/build-docker-image-matrix.py index 51d7fd0083d1..b7c4d2daaefd 100644 --- a/dev/build-docker-image-matrix.py +++ b/dev/build-docker-image-matrix.py @@ -168,6 +168,13 @@ def tag_latest_ubuntu_with_flwr_version(image: BaseImage) -> List[str]: tag_latest_ubuntu_with_flwr_version, lambda image: image.distro.name == DistroName.UBUNTU, ) + # ubuntu images for each supported python version + + generate_binary_images( + "superexec", + base_images, + tag_latest_ubuntu_with_flwr_version, + lambda image: image.distro.name == DistroName.UBUNTU, + ) ) print( diff --git a/dev/build-docs.sh b/dev/build-docs.sh index f8d4f91508de..f4bf958b0ebf 100755 --- 
a/dev/build-docs.sh +++ b/dev/build-docs.sh @@ -8,9 +8,7 @@ cd $ROOT ./dev/build-baseline-docs.sh cd $ROOT -./dev/update-examples.sh -cd examples/doc -make docs +python dev/build-example-docs.py cd $ROOT ./datasets/dev/build-flwr-datasets-docs.sh diff --git a/dev/build-example-docs.py b/dev/build-example-docs.py new file mode 100644 index 000000000000..367994708bf9 --- /dev/null +++ b/dev/build-example-docs.py @@ -0,0 +1,283 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Build the Flower Example docs.""" + +import os +import shutil +import re +import subprocess +from pathlib import Path + +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +INDEX = os.path.join(ROOT, "examples", "doc", "source", "index.rst") + +initial_text = """ +Flower Examples Documentation +----------------------------- + +Welcome to Flower Examples' documentation. `Flower `_ is +a friendly federated learning framework. + +Join the Flower Community +------------------------- + +The Flower Community is growing quickly - we're a friendly group of researchers, +engineers, students, professionals, academics, and other enthusiasts. + +.. 
button-link:: https://flower.ai/join-slack + :color: primary + :shadow: + + Join us on Slack + +Quickstart Examples +------------------- + +Flower Quickstart Examples are a collection of demo projects that show how you +can use Flower in combination with other existing frameworks or technologies. + +""" + +table_headers = ( + "\n.. list-table::\n :widths: 50 15 15 15\n " + ":header-rows: 1\n\n * - Title\n - Framework\n - Dataset\n - Tags\n\n" +) + +categories = { + "quickstart": {"table": table_headers, "list": ""}, + "advanced": {"table": table_headers, "list": ""}, + "other": {"table": table_headers, "list": ""}, +} + +urls = { + # Frameworks + "Android": "https://www.android.com/", + "C++": "https://isocpp.org/", + "Docker": "https://www.docker.com/", + "JAX": "https://jax.readthedocs.io/en/latest/", + "Java": "https://www.java.com/", + "Keras": "https://keras.io/", + "Kotlin": "https://kotlinlang.org/", + "mlcube": "https://docs.mlcommons.org/mlcube/", + "MLX": "https://ml-explore.github.io/mlx/build/html/index.html", + "MONAI": "https://monai.io/", + "PEFT": "https://huggingface.co/docs/peft/index", + "Swift": "https://www.swift.org/", + "TensorFlowLite": "https://www.tensorflow.org/lite", + "fastai": "https://fast.ai/", + "lifelines": "https://lifelines.readthedocs.io/en/latest/index.html", + "lightning": "https://lightning.ai/docs/pytorch/stable/", + "numpy": "https://numpy.org/", + "opacus": "https://opacus.ai/", + "pandas": "https://pandas.pydata.org/", + "scikit-learn": "https://scikit-learn.org/", + "tabnet": "https://github.com/titu1994/tf-TabNet", + "tensorboard": "https://www.tensorflow.org/tensorboard", + "tensorflow": "https://www.tensorflow.org/", + "torch": "https://pytorch.org/", + "torchvision": "https://pytorch.org/vision/stable/index.html", + "transformers": "https://huggingface.co/docs/transformers/index", + "wandb": "https://wandb.ai/home", + "whisper": "https://huggingface.co/openai/whisper-tiny", + "xgboost": 
"https://xgboost.readthedocs.io/en/stable/", + # Datasets + "Adult Census Income": "https://www.kaggle.com/datasets/uciml/adult-census-income/data", + "Alpaca-GPT4": "https://huggingface.co/datasets/vicgalle/alpaca-gpt4", + "CIFAR-10": "https://huggingface.co/datasets/uoft-cs/cifar10", + "HIGGS": "https://archive.ics.uci.edu/dataset/280/higgs", + "IMDB": "https://huggingface.co/datasets/stanfordnlp/imdb", + "Iris": "https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html", + "MNIST": "https://huggingface.co/datasets/ylecun/mnist", + "MedNIST": "https://medmnist.com/", + "Oxford Flower-102": "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/", + "SpeechCommands": "https://huggingface.co/datasets/google/speech_commands", + "Titanic": "https://www.kaggle.com/competitions/titanic", + "Waltons": "https://lifelines.readthedocs.io/en/latest/lifelines.datasets.html#lifelines.datasets.load_waltons", +} + + +def _convert_to_link(search_result): + if "," in search_result: + result = "" + for part in search_result.split(","): + result += f"{_convert_to_link(part)}, " + return result[:-2] + else: + search_result = search_result.strip() + name, url = search_result, urls.get(search_result, None) + if url: + return f"`{name.strip()} <{url.strip()}>`_" + else: + return search_result + + +def _read_metadata(example): + with open(os.path.join(example, "README.md")) as f: + content = f.read() + + metadata_match = re.search(r"^---(.*?)^---", content, re.DOTALL | re.MULTILINE) + if not metadata_match: + raise ValueError("Metadata block not found") + metadata = metadata_match.group(1) + + title_match = re.search(r"^# (.+)$", content, re.MULTILINE) + if not title_match: + raise ValueError("Title not found in metadata") + title = title_match.group(1).strip() + + tags_match = re.search(r"^tags:\s*\[(.+?)\]$", metadata, re.MULTILINE) + if not tags_match: + raise ValueError("Tags not found in metadata") + tags = tags_match.group(1).strip() + + dataset_match = re.search( 
+ r"^dataset:\s*\[(.*?)\]$", metadata, re.DOTALL | re.MULTILINE + ) + if not dataset_match: + raise ValueError("Dataset not found in metadata") + dataset = dataset_match.group(1).strip() + + framework_match = re.search( + r"^framework:\s*\[(.*?|)\]$", metadata, re.DOTALL | re.MULTILINE + ) + if not framework_match: + raise ValueError("Framework not found in metadata") + framework = framework_match.group(1).strip() + + dataset = _convert_to_link(re.sub(r"\s+", " ", dataset).strip()) + framework = _convert_to_link(re.sub(r"\s+", " ", framework).strip()) + return title, tags, dataset, framework + + +def _add_table_entry(example, tag, table_var): + title, tags, dataset, framework = _read_metadata(example) + example_name = Path(example).stem + table_entry = ( + f" * - `{title} <{example_name}.html>`_ \n " + f"- {framework} \n - {dataset} \n - {tags}\n\n" + ) + if tag in tags: + categories[table_var]["table"] += table_entry + categories[table_var]["list"] += f" {example_name}\n" + return True + return False + + +def _copy_markdown_files(example): + for file in os.listdir(example): + if file.endswith(".md"): + src = os.path.join(example, file) + dest = os.path.join( + ROOT, "examples", "doc", "source", os.path.basename(example) + ".md" + ) + shutil.copyfile(src, dest) + + +def _add_gh_button(example): + gh_text = f'[View on GitHub](https://github.com/adap/flower/blob/main/examples/{example})' + readme_file = os.path.join(ROOT, "examples", "doc", "source", example + ".md") + with open(readme_file, "r+") as f: + content = f.read() + if gh_text not in content: + content = re.sub( + r"(^# .+$)", rf"\1\n\n{gh_text}", content, count=1, flags=re.MULTILINE + ) + f.seek(0) + f.write(content) + f.truncate() + + +def _copy_images(example): + static_dir = os.path.join(example, "_static") + dest_dir = os.path.join(ROOT, "examples", "doc", "source", "_static") + if os.path.isdir(static_dir): + for file in os.listdir(static_dir): + if file.endswith((".jpg", ".png", ".jpeg")): + 
shutil.copyfile( + os.path.join(static_dir, file), os.path.join(dest_dir, file) + ) + + +def _add_all_entries(): + examples_dir = os.path.join(ROOT, "examples") + for example in sorted(os.listdir(examples_dir)): + example_path = os.path.join(examples_dir, example) + if os.path.isdir(example_path) and example != "doc": + _copy_markdown_files(example_path) + _add_gh_button(example) + _copy_images(example) + + +def _main(): + if os.path.exists(INDEX): + os.remove(INDEX) + + with open(INDEX, "w") as index_file: + index_file.write(initial_text) + + examples_dir = os.path.join(ROOT, "examples") + for example in sorted(os.listdir(examples_dir)): + example_path = os.path.join(examples_dir, example) + if os.path.isdir(example_path) and example != "doc": + _copy_markdown_files(example_path) + _add_gh_button(example) + _copy_images(example_path) + if not _add_table_entry(example_path, "quickstart", "quickstart"): + if not _add_table_entry(example_path, "comprehensive", "comprehensive"): + if not _add_table_entry(example_path, "advanced", "advanced"): + _add_table_entry(example_path, "", "other") + + with open(INDEX, "a") as index_file: + index_file.write(categories["quickstart"]["table"]) + + index_file.write("\nAdvanced Examples\n-----------------\n") + index_file.write( + "Advanced Examples are mostly for users that are both familiar with " + "Federated Learning but also somewhat familiar with Flower's main " + "features.\n" + ) + index_file.write(categories["advanced"]["table"]) + + index_file.write("\nOther Examples\n--------------\n") + index_file.write( + "Flower Examples are a collection of example projects written with " + "Flower that explore different domains and features. You can check " + "which examples already exist and/or contribute your own example.\n" + ) + index_file.write(categories["other"]["table"]) + + _add_all_entries() + + index_file.write( + "\n.. 
toctree::\n :maxdepth: 1\n :caption: Quickstart\n :hidden:\n\n" + ) + index_file.write(categories["quickstart"]["list"]) + + index_file.write( + "\n.. toctree::\n :maxdepth: 1\n :caption: Advanced\n :hidden:\n\n" + ) + index_file.write(categories["advanced"]["list"]) + + index_file.write( + "\n.. toctree::\n :maxdepth: 1\n :caption: Others\n :hidden:\n\n" + ) + index_file.write(categories["other"]["list"]) + + index_file.write("\n") + + +if __name__ == "__main__": + _main() + subprocess.call(f"cd {ROOT}/examples/doc && make html", shell=True) diff --git a/dev/changelog_config.toml b/dev/changelog_config.toml index c5ff1bcdd1c1..637ea9b4b2c6 100644 --- a/dev/changelog_config.toml +++ b/dev/changelog_config.toml @@ -3,7 +3,7 @@ type = ["ci", "docs", "feat", "fix", "refactor", "break"] -project = ["framework", "baselines", "datasets", "examples"] +project = ["framework", "baselines", "datasets", "examples", "benchmarks"] scope = "skip" diff --git a/dev/format.sh b/dev/format.sh index 71edf9c6065a..e1e2abc307f1 100755 --- a/dev/format.sh +++ b/dev/format.sh @@ -18,6 +18,11 @@ find src/proto/flwr/proto -name *.proto | grep "\.proto" | xargs clang-format -i python -m black -q examples python -m docformatter -i -r examples +# Benchmarks +python -m isort benchmarks +python -m black -q benchmarks +python -m docformatter -i -r benchmarks + # E2E python -m isort e2e python -m black -q e2e diff --git a/dev/test.sh b/dev/test.sh index 8cbe88c9298b..58ac0b3d24cd 100755 --- a/dev/test.sh +++ b/dev/test.sh @@ -11,11 +11,11 @@ clang-format --Werror --dry-run src/proto/flwr/proto/* echo "- clang-format: done" echo "- isort: start" -python -m isort --check-only --skip src/py/flwr/proto src/py/flwr e2e +python -m isort --check-only --skip src/py/flwr/proto src/py/flwr benchmarks e2e echo "- isort: done" echo "- black: start" -python -m black --exclude "src\/py\/flwr\/proto" --check src/py/flwr examples e2e +python -m black --exclude "src\/py\/flwr\/proto" --check src/py/flwr 
benchmarks examples e2e echo "- black: done" echo "- init_py_check: start" diff --git a/dev/update-examples.sh b/dev/update-examples.sh deleted file mode 100755 index 1076b4621984..000000000000 --- a/dev/update-examples.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -set -e -cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../ - -ROOT=`pwd` -INDEX=$ROOT/examples/doc/source/index.md -INSERT_LINE=6 - -copy_markdown_files () { - for file in $1/*.md; do - # Copy the README into the source of the Example docs as the name of the example - if [[ $(basename "$file") = "README.md" ]]; then - cp $file $ROOT/examples/doc/source/$1.md 2>&1 >/dev/null - else - # If the example contains other markdown files, copy them to the source of the Example docs - cp $file $ROOT/examples/doc/source/$(basename "$file") 2>&1 >/dev/null - fi - done -} - -add_gh_button () { - gh_text="[\"View](https://github.com/adap/flower/blob/main/examples/$1)" - readme_file="$ROOT/examples/doc/source/$1.md" - - if ! 
grep -Fq "$gh_text" "$readme_file"; then - awk -v text="$gh_text" ' - /^# / && !found { - print $0 "\n" text; - found=1; - next; - } - { print } - ' "$readme_file" > tmpfile && mv tmpfile "$readme_file" - fi -} - -copy_images () { - if [ -d "$1/_static" ]; then - cp $1/_static/**.{jpg,png,jpeg} $ROOT/examples/doc/source/_static/ 2>/dev/null || true - fi -} - -add_to_index () { - (echo $INSERT_LINE; echo a; echo $1; echo .; echo wq) | ed $INDEX 2>&1 >/dev/null -} - -add_single_entry () { - # Copy markdown files to correct folder - copy_markdown_files $1 - - # Add button linked to GitHub - add_gh_button $1 - - # Copy all images of the _static folder into the examples - # docs static folder - copy_images $1 - - # Insert the name of the example into the index file - add_to_index $1 -} - -add_all_entries () { - cd $ROOT/examples - # Iterate through each folder in examples/ - for d in $(printf '%s\n' */ | sort -V); do - # Add entry based on the name of the folder - example=${d%/} - - if [[ $example != doc ]]; then - add_single_entry $example - fi - done -} - -# Clean up before starting -rm -f $ROOT/examples/doc/source/*.md -rm -f $INDEX - -# Create empty index file -touch $INDEX - -echo "# Flower Examples Documentation" >> $INDEX -echo "" >> $INDEX -echo "\`\`\`{toctree}" >> $INDEX -echo "---" >> $INDEX -echo "maxdepth: 1" >> $INDEX -echo "---" >> $INDEX - -add_all_entries - -echo "\`\`\`" >> $INDEX diff --git a/doc/source/_static/docker-ci-release.png b/doc/source/_static/docker-ci-release.png deleted file mode 100644 index 6ec97ce9fb06..000000000000 Binary files a/doc/source/_static/docker-ci-release.png and /dev/null differ diff --git a/doc/source/contributor-how-to-release-flower.rst b/doc/source/contributor-how-to-release-flower.rst index fc4c2d436b05..4853d87bc4c1 100644 --- a/doc/source/contributor-how-to-release-flower.rst +++ b/doc/source/contributor-how-to-release-flower.rst @@ -12,24 +12,6 @@ The version number of a release is stated in ``pyproject.toml``. 
To release a ne 2. Once the changelog has been updated with all the changes, run ``./dev/prepare-release-changelog.sh v``, where ```` is the version stated in ``pyproject.toml`` (notice the ``v`` added before it). This will replace the ``Unreleased`` header of the changelog by the version and current date, and it will add a thanking message for the contributors. Open a pull request with those changes. 3. Once the pull request is merged, tag the release commit with the version number as soon as the PR is merged: ``git tag v`` (notice the ``v`` added before the version number), then ``git push --tags``. This will create a draft release on GitHub containing the correct artifacts and the relevant part of the changelog. 4. Check the draft release on GitHub, and if everything is good, publish it. -5. Trigger the CI for building the Docker images. - -To trigger the workflow, a collaborator must create a ``workflow_dispatch`` event in the -GitHub CI. This can be done either through the UI or via the GitHub CLI. The event requires only one -input, the Flower version, to be released. - -**Via the UI** - -1. Go to the ``Build docker images`` workflow `page `_. -2. Click on the ``Run workflow`` button and type the new version of Flower in the ``Version of Flower`` input field. -3. Click on the **green** ``Run workflow`` button. - -.. image:: _static/docker-ci-release.png - -**Via the GitHub CI** - -1. Make sure you are logged in via ``gh auth login`` and that the current working directory is the root of the Flower repository. -2. Trigger the workflow via ``gh workflow run docker-images.yml -f flwr-version=``. After the release ----------------- diff --git a/doc/source/explanation-differential-privacy.rst b/doc/source/explanation-differential-privacy.rst index 69fd333f9b13..e488f5ccbd57 100644 --- a/doc/source/explanation-differential-privacy.rst +++ b/doc/source/explanation-differential-privacy.rst @@ -32,7 +32,7 @@ and for all possible outputs S ⊆ Range(A): .. 
math:: \small - P[M(D_{1} \in A)] \leq e^{\delta} P[M(D_{2} \in A)] + \delta + P[M(D_{1}) \in S] \leq e^{\epsilon} P[M(D_{2}) \in S] + \delta The :math:`\epsilon` parameter, also known as the privacy budget, is a metric of privacy loss. diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst index 7ac055138814..34ad5f6e99c0 100644 --- a/doc/source/tutorial-quickstart-xgboost.rst +++ b/doc/source/tutorial-quickstart-xgboost.rst @@ -96,26 +96,26 @@ Prior to local training, we require loading the HIGGS dataset from Flower Datase fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) # Load the partition for this `node_id` - partition = fds.load_partition(node_id=args.node_id, split="train") + partition = fds.load_partition(partition_id=args.partition_id, split="train") partition.set_format("numpy") -In this example, we split the dataset into two partitions with uniform distribution (:code:`IidPartitioner(num_partitions=2)`). -Then, we load the partition for the given client based on :code:`node_id`: +In this example, we split the dataset into 30 partitions with uniform distribution (:code:`IidPartitioner(num_partitions=30)`). +Then, we load the partition for the given client based on :code:`partition_id`: .. code-block:: python - # We first define arguments parser for user to specify the client/node ID. + # We first define arguments parser for user to specify the client/partition ID. parser = argparse.ArgumentParser() parser.add_argument( - "--node-id", + "--partition-id", default=0, type=int, - help="Node ID used for the current client.", + help="Partition ID used for the current client.", ) args = parser.parse_args() - # Load the partition for this `node_id`. - partition = fds.load_partition(idx=args.node_id, split="train") + # Load the partition for this `partition_id`. 
+ partition = fds.load_partition(idx=args.partition_id, split="train") partition.set_format("numpy") After that, we do train/test splitting on the given partition (client's local data), and transform data format for :code:`xgboost` package. @@ -186,12 +186,23 @@ We follow the general rule to define :code:`XgbClient` class inherited from :cod .. code-block:: python class XgbClient(fl.client.Client): - def __init__(self): - self.bst = None - self.config = None + def __init__( + self, + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + ): + self.train_dmatrix = train_dmatrix + self.valid_dmatrix = valid_dmatrix + self.num_train = num_train + self.num_val = num_val + self.num_local_round = num_local_round + self.params = params -The :code:`self.bst` is used to keep the Booster objects that remain consistent across rounds, -allowing them to store predictions from trees integrated in earlier rounds and maintain other essential data structures for training. +All required parameters defined above are passed to :code:`XgbClient`'s constructor. Then, we override :code:`get_parameters`, :code:`fit` and :code:`evaluate` methods insides :code:`XgbClient` class as follows. @@ -214,27 +225,27 @@ As a result, let's return an empty tensor in :code:`get_parameters` when it is c .. 
code-block:: python def fit(self, ins: FitIns) -> FitRes: - if not self.bst: + global_round = int(ins.config["global_round"]) + if global_round == 1: # First round local training - log(INFO, "Start training at round 1") bst = xgb.train( - params, - train_dmatrix, - num_boost_round=num_local_round, - evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")], + self.params, + self.train_dmatrix, + num_boost_round=self.num_local_round, + evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")], ) - self.config = bst.save_config() - self.bst = bst else: + bst = xgb.Booster(params=self.params) for item in ins.parameters.tensors: global_model = bytearray(item) # Load global model into booster - self.bst.load_model(global_model) - self.bst.load_config(self.config) + bst.load_model(global_model) - bst = self._local_boost() + # Local training + bst = self._local_boost(bst) + # Save model local_model = bst.save_raw("json") local_model_bytes = bytes(local_model) @@ -244,60 +255,81 @@ As a result, let's return an empty tensor in :code:`get_parameters` when it is c message="OK", ), parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), - num_examples=num_train, + num_examples=self.num_train, metrics={}, ) In :code:`fit`, at the first round, we call :code:`xgb.train()` to build up the first set of trees. -the returned Booster object and config are stored in :code:`self.bst` and :code:`self.config`, respectively. -From the second round, we load the global model sent from server to :code:`self.bst`, +From the second round, we load the global model sent from the server into a newly built Booster object, and then update model weights on local training data with function :code:`_local_boost` as follows: .. code-block:: python - def _local_boost(self): + def _local_boost(self, bst_input): # Update trees based on local training data. 
- for i in range(num_local_round): - self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) + for i in range(self.num_local_round): + bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds()) - # Extract the last N=num_local_round trees for sever aggregation - bst = self.bst[ - self.bst.num_boosted_rounds() - - num_local_round : self.bst.num_boosted_rounds() + # Bagging: extract the last N=num_local_round trees for sever aggregation + bst = bst_input[ + bst_input.num_boosted_rounds() + - self.num_local_round : bst_input.num_boosted_rounds() ] -Given :code:`num_local_round`, we update trees by calling :code:`self.bst.update` method. + return bst + +Given :code:`num_local_round`, we update trees by calling :code:`bst_input.update` method. After training, the last :code:`N=num_local_round` trees will be extracted to send to the server. .. code-block:: python def evaluate(self, ins: EvaluateIns) -> EvaluateRes: - eval_results = self.bst.eval_set( - evals=[(valid_dmatrix, "valid")], - iteration=self.bst.num_boosted_rounds() - 1, + # Load global model + bst = xgb.Booster(params=self.params) + for para in ins.parameters.tensors: + para_b = bytearray(para) + bst.load_model(para_b) + + # Run evaluation + eval_results = bst.eval_set( + evals=[(self.valid_dmatrix, "valid")], + iteration=bst.num_boosted_rounds() - 1, ) auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + global_round = ins.config["global_round"] + log(INFO, f"AUC = {auc} at round {global_round}") + return EvaluateRes( status=Status( code=Code.OK, message="OK", ), loss=0.0, - num_examples=num_val, + num_examples=self.num_val, metrics={"AUC": auc}, ) -In :code:`evaluate`, we call :code:`self.bst.eval_set` function to conduct evaluation on valid set. +In :code:`evaluate`, after loading the global model, we call :code:`bst.eval_set` function to conduct evaluation on valid set. The AUC value will be returned. 
Now, we can create an instance of our class :code:`XgbClient` and add one line to actually run this client: .. code-block:: python - fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) + fl.client.start_client( + server_address="127.0.0.1:8080", + client=XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + ).to_client(), + ) -That's it for the client. We only have to implement :code:`Client`and call :code:`fl.client.start_client()`. +That's it for the client. We only have to implement :code:`Client` and call :code:`fl.client.start_client()`. The string :code:`"127.0.0.1:8080"` tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use :code:`"127.0.0.1:8080"`. If we run a truly federated workload with the server and @@ -325,6 +357,8 @@ We first define a strategy for XGBoost bagging aggregation. min_evaluate_clients=2, fraction_evaluate=1.0, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=config_func, + on_fit_config_fn=config_func, ) def evaluate_metrics_aggregation(eval_metrics): @@ -336,8 +370,16 @@ We first define a strategy for XGBoost bagging aggregation. metrics_aggregated = {"AUC": auc_aggregated} return metrics_aggregated + def config_func(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + We use two clients for this example. An :code:`evaluate_metrics_aggregation` function is defined to collect the AUC values from clients and compute their weighted average. +The :code:`config_func` function returns the current FL round number to the client's :code:`fit()` and :code:`evaluate()` methods. 
Then, we start the server: @@ -346,7 +388,7 @@ Then, we start the server: # Start Flower server fl.server.start_server( server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=num_rounds), + config=fl.server.ServerConfig(num_rounds=5), strategy=strategy, ) @@ -535,52 +577,66 @@ Open a new terminal and start the first client: .. code-block:: shell - $ python3 client.py --node-id=0 + $ python3 client.py --partition-id=0 Open another terminal and start the second client: .. code-block:: shell - $ python3 client.py --node-id=1 + $ python3 client.py --partition-id=1 Each client will have its own dataset. You should now see how the training does in the very first terminal (the one that started the server): .. code-block:: shell - INFO flwr 2023-11-20 11:21:56,454 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=5, round_timeout=None) - INFO flwr 2023-11-20 11:21:56,473 | app.py:176 | Flower ECE: gRPC server running (5 rounds), SSL is disabled - INFO flwr 2023-11-20 11:21:56,473 | server.py:89 | Initializing global parameters - INFO flwr 2023-11-20 11:21:56,473 | server.py:276 | Requesting initial parameters from one random client - INFO flwr 2023-11-20 11:22:38,302 | server.py:280 | Received initial parameters from one random client - INFO flwr 2023-11-20 11:22:38,302 | server.py:91 | Evaluating initial parameters - INFO flwr 2023-11-20 11:22:38,302 | server.py:104 | FL starting - DEBUG flwr 2023-11-20 11:22:38,302 | server.py:222 | fit_round 1: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,636 | server.py:236 | fit_round 1 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,643 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,653 | server.py:187 | evaluate_round 1 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,653 | server.py:222 | fit_round 2: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 
11:22:38,721 | server.py:236 | fit_round 2 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,745 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,756 | server.py:187 | evaluate_round 2 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,756 | server.py:222 | fit_round 3: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,831 | server.py:236 | fit_round 3 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,868 | server.py:173 | evaluate_round 3: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,881 | server.py:187 | evaluate_round 3 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:38,881 | server.py:222 | fit_round 4: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:38,960 | server.py:236 | fit_round 4 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:39,012 | server.py:173 | evaluate_round 4: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:39,026 | server.py:187 | evaluate_round 4 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:39,026 | server.py:222 | fit_round 5: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:39,111 | server.py:236 | fit_round 5 received 2 results and 0 failures - DEBUG flwr 2023-11-20 11:22:39,177 | server.py:173 | evaluate_round 5: strategy sampled 2 clients (out of 2) - DEBUG flwr 2023-11-20 11:22:39,193 | server.py:187 | evaluate_round 5 received 2 results and 0 failures - INFO flwr 2023-11-20 11:22:39,193 | server.py:153 | FL finished in 0.8905023969999988 - INFO flwr 2023-11-20 11:22:39,193 | app.py:226 | app_fit: losses_distributed [(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)] - INFO flwr 2023-11-20 11:22:39,193 | app.py:227 | app_fit: metrics_distributed_fit {} - INFO flwr 2023-11-20 11:22:39,193 | app.py:228 | app_fit: metrics_distributed {'AUC': [(1, 0.7572), (2, 0.7705), (3, 0.77595), (4, 
0.78), (5, 0.78385)]} - INFO flwr 2023-11-20 11:22:39,193 | app.py:229 | app_fit: losses_centralized [] - INFO flwr 2023-11-20 11:22:39,193 | app.py:230 | app_fit: metrics_centralized {} + INFO : Starting Flower server, config: num_rounds=5, no round_timeout + INFO : Flower ECE: gRPC server running (5 rounds), SSL is disabled + INFO : [INIT] + INFO : Requesting initial parameters from one random client + INFO : Received initial parameters from one random client + INFO : Evaluating initial global parameters + INFO : + INFO : [ROUND 1] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 2] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 3] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 4] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [ROUND 5] + INFO : configure_fit: strategy sampled 2 clients (out of 2) + INFO : aggregate_fit: received 2 results and 0 failures + INFO : configure_evaluate: strategy sampled 2 clients (out of 2) + INFO : aggregate_evaluate: received 2 results and 0 failures + INFO : + INFO : [SUMMARY] + INFO : Run finished 5 round(s) in 1.67s + INFO : History (loss, 
distributed): + INFO : round 1: 0 + INFO : round 2: 0 + INFO : round 3: 0 + INFO : round 4: 0 + INFO : round 5: 0 + INFO : History (metrics, distributed, evaluate): + INFO : {'AUC': [(1, 0.76755), (2, 0.775), (3, 0.77935), (4, 0.7836), (5, 0.7872)]} Congratulations! You've successfully built and run your first federated XGBoost system. diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py index e82f17088bd9..c7b0d59b8ea5 100644 --- a/e2e/bare-client-auth/client.py +++ b/e2e/bare-client-auth/client.py @@ -1,13 +1,14 @@ import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context model_params = np.array([1]) objective = 5 # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -23,10 +24,10 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 8f5c1412fd01..4a682af3aec3 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -2,14 +2,15 @@ import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context model_params = np.array([1]) objective = 5 # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -25,17 +26,17 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( 
server_address="127.0.0.1:8080", client=FlowerClient().to_client(), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 402d775ac3a9..943e60d5db9f 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -2,8 +2,8 @@ import numpy as np -import flwr as fl -from flwr.common import ConfigsRecord +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import ConfigsRecord, Context SUBSET_SIZE = 1000 STATE_VAR = "timestamp" @@ -14,7 +14,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -51,16 +51,14 @@ def evaluate(self, parameters, config): ) -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/docker/client.py b/e2e/docker/client.py index 8451b810416b..44313c7c3af6 100644 --- a/e2e/docker/client.py +++ b/e2e/docker/client.py @@ -9,6 +9,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from flwr.client import ClientApp, NumPyClient +from flwr.common import Context # ############################################################################# # 1. 
Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -122,7 +123,7 @@ def evaluate(self, parameters, config): return loss, len(testloader.dataset), {"accuracy": accuracy} -def client_fn(cid: str): +def client_fn(context: Context): """Create and return an instance of Flower `Client`.""" return FlowerClient().to_client() diff --git a/e2e/fastai/README.md b/e2e/framework-fastai/README.md similarity index 100% rename from e2e/fastai/README.md rename to e2e/framework-fastai/README.md diff --git a/e2e/fastai/client.py b/e2e/framework-fastai/client.py similarity index 90% rename from e2e/fastai/client.py rename to e2e/framework-fastai/client.py index 1d98a1134941..161b27b5a548 100644 --- a/e2e/fastai/client.py +++ b/e2e/framework-fastai/client.py @@ -5,7 +5,8 @@ import torch from fastai.vision.all import * -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context warnings.filterwarnings("ignore", category=UserWarning) @@ -29,7 +30,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in learn.model.state_dict().items()] @@ -49,18 +50,18 @@ def evaluate(self, parameters, config): return loss, len(dls.valid), {"accuracy": 1 - error_rate} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/fastai/pyproject.toml b/e2e/framework-fastai/pyproject.toml similarity index 100% rename from e2e/fastai/pyproject.toml rename to e2e/framework-fastai/pyproject.toml diff --git a/e2e/fastai/simulation.py b/e2e/framework-fastai/simulation.py similarity index 100% rename from e2e/fastai/simulation.py rename to 
e2e/framework-fastai/simulation.py diff --git a/e2e/jax/README.md b/e2e/framework-jax/README.md similarity index 100% rename from e2e/jax/README.md rename to e2e/framework-jax/README.md diff --git a/e2e/jax/client.py b/e2e/framework-jax/client.py similarity index 86% rename from e2e/jax/client.py rename to e2e/framework-jax/client.py index 347a005d923a..c9ff67b3e38e 100644 --- a/e2e/jax/client.py +++ b/e2e/framework-jax/client.py @@ -6,7 +6,8 @@ import jax_training import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Load data and determine model shape train_x, train_y, test_x, test_y = jax_training.load_data() @@ -14,7 +15,7 @@ model_shape = train_x.shape[1:] -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self): self.params = jax_training.load_model(model_shape) @@ -48,16 +49,14 @@ def evaluate( return float(loss), num_examples, {"loss": float(loss)} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/jax/jax_training.py b/e2e/framework-jax/jax_training.py similarity index 100% rename from e2e/jax/jax_training.py rename to e2e/framework-jax/jax_training.py diff --git a/e2e/jax/pyproject.toml b/e2e/framework-jax/pyproject.toml similarity index 100% rename from e2e/jax/pyproject.toml rename to e2e/framework-jax/pyproject.toml diff --git a/e2e/jax/simulation.py b/e2e/framework-jax/simulation.py similarity index 100% rename from e2e/jax/simulation.py rename to e2e/framework-jax/simulation.py diff --git a/e2e/opacus/.gitignore b/e2e/framework-opacus/.gitignore similarity index 100% rename from 
e2e/opacus/.gitignore rename to e2e/framework-opacus/.gitignore diff --git a/e2e/opacus/README.md b/e2e/framework-opacus/README.md similarity index 100% rename from e2e/opacus/README.md rename to e2e/framework-opacus/README.md diff --git a/e2e/opacus/client.py b/e2e/framework-opacus/client.py similarity index 96% rename from e2e/opacus/client.py rename to e2e/framework-opacus/client.py index c9ebe319063a..167fa4584e37 100644 --- a/e2e/opacus/client.py +++ b/e2e/framework-opacus/client.py @@ -9,7 +9,8 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Define parameters. PARAMS = { @@ -95,7 +96,7 @@ def load_data(): # Define Flower client. -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model) -> None: super().__init__() # Create a privacy engine which will add DP and keep track of the privacy budget. 
@@ -139,16 +140,16 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} -def client_fn(cid): +def client_fn(context: Context): model = Net() return FlowerClient(model).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient(model).to_client() ) diff --git a/e2e/opacus/pyproject.toml b/e2e/framework-opacus/pyproject.toml similarity index 100% rename from e2e/opacus/pyproject.toml rename to e2e/framework-opacus/pyproject.toml diff --git a/e2e/opacus/simulation.py b/e2e/framework-opacus/simulation.py similarity index 100% rename from e2e/opacus/simulation.py rename to e2e/framework-opacus/simulation.py diff --git a/e2e/pandas/README.md b/e2e/framework-pandas/README.md similarity index 100% rename from e2e/pandas/README.md rename to e2e/framework-pandas/README.md diff --git a/e2e/pandas/client.py b/e2e/framework-pandas/client.py similarity index 82% rename from e2e/pandas/client.py rename to e2e/framework-pandas/client.py index 19e15f5a3b11..0c3300e1dd3f 100644 --- a/e2e/pandas/client.py +++ b/e2e/framework-pandas/client.py @@ -3,7 +3,8 @@ import numpy as np import pandas as pd -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context df = pd.read_csv("./data/client.csv") @@ -16,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def fit( self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: @@ -32,17 +33,17 @@ def fit( ) -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - 
fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/pandas/pyproject.toml b/e2e/framework-pandas/pyproject.toml similarity index 100% rename from e2e/pandas/pyproject.toml rename to e2e/framework-pandas/pyproject.toml diff --git a/e2e/pandas/server.py b/e2e/framework-pandas/server.py similarity index 100% rename from e2e/pandas/server.py rename to e2e/framework-pandas/server.py diff --git a/e2e/pandas/simulation.py b/e2e/framework-pandas/simulation.py similarity index 100% rename from e2e/pandas/simulation.py rename to e2e/framework-pandas/simulation.py diff --git a/e2e/pandas/strategy.py b/e2e/framework-pandas/strategy.py similarity index 100% rename from e2e/pandas/strategy.py rename to e2e/framework-pandas/strategy.py diff --git a/e2e/pytorch-lightning/README.md b/e2e/framework-pytorch-lightning/README.md similarity index 100% rename from e2e/pytorch-lightning/README.md rename to e2e/framework-pytorch-lightning/README.md diff --git a/e2e/pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py similarity index 89% rename from e2e/pytorch-lightning/client.py rename to e2e/framework-pytorch-lightning/client.py index fdd55b3dc344..bf291a1ca2c5 100644 --- a/e2e/pytorch-lightning/client.py +++ b/e2e/framework-pytorch-lightning/client.py @@ -4,10 +4,11 @@ import pytorch_lightning as pl import torch -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model, train_loader, val_loader, test_loader): self.model = model self.train_loader = train_loader @@ -51,7 +52,7 @@ def _set_parameters(model, parameters): model.load_state_dict(state_dict, strict=True) -def client_fn(cid): +def client_fn(context: Context): model = mnist.LitAutoEncoder() train_loader, val_loader, test_loader = mnist.load_data() @@ -59,7 +60,7 @@ def 
client_fn(cid): return FlowerClient(model, train_loader, val_loader, test_loader).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) @@ -71,7 +72,7 @@ def main() -> None: # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() - fl.client.start_client(server_address="127.0.0.1:8080", client=client) + start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/e2e/pytorch-lightning/mnist.py b/e2e/framework-pytorch-lightning/mnist.py similarity index 100% rename from e2e/pytorch-lightning/mnist.py rename to e2e/framework-pytorch-lightning/mnist.py diff --git a/e2e/pytorch-lightning/pyproject.toml b/e2e/framework-pytorch-lightning/pyproject.toml similarity index 100% rename from e2e/pytorch-lightning/pyproject.toml rename to e2e/framework-pytorch-lightning/pyproject.toml diff --git a/e2e/pytorch-lightning/simulation.py b/e2e/framework-pytorch-lightning/simulation.py similarity index 100% rename from e2e/pytorch-lightning/simulation.py rename to e2e/framework-pytorch-lightning/simulation.py diff --git a/e2e/pytorch/README.md b/e2e/framework-pytorch/README.md similarity index 100% rename from e2e/pytorch/README.md rename to e2e/framework-pytorch/README.md diff --git a/e2e/pytorch/client.py b/e2e/framework-pytorch/client.py similarity index 95% rename from e2e/pytorch/client.py rename to e2e/framework-pytorch/client.py index dbfbfed1ffa7..ab4bc7b5c5b9 100644 --- a/e2e/pytorch/client.py +++ b/e2e/framework-pytorch/client.py @@ -10,8 +10,8 @@ from torchvision.transforms import Compose, Normalize, ToTensor from tqdm import tqdm -import flwr as fl -from flwr.common import ConfigsRecord +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import ConfigsRecord, Context # ############################################################################# # 1. 
Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -89,7 +89,7 @@ def load_data(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] @@ -136,18 +136,18 @@ def set_parameters(model, parameters): return -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/pytorch/pyproject.toml b/e2e/framework-pytorch/pyproject.toml similarity index 100% rename from e2e/pytorch/pyproject.toml rename to e2e/framework-pytorch/pyproject.toml diff --git a/e2e/pytorch/simulation.py b/e2e/framework-pytorch/simulation.py similarity index 100% rename from e2e/pytorch/simulation.py rename to e2e/framework-pytorch/simulation.py diff --git a/e2e/pytorch/simulation_next.py b/e2e/framework-pytorch/simulation_next.py similarity index 100% rename from e2e/pytorch/simulation_next.py rename to e2e/framework-pytorch/simulation_next.py diff --git a/e2e/scikit-learn/README.md b/e2e/framework-scikit-learn/README.md similarity index 100% rename from e2e/scikit-learn/README.md rename to e2e/framework-scikit-learn/README.md diff --git a/e2e/scikit-learn/client.py b/e2e/framework-scikit-learn/client.py similarity index 86% rename from e2e/scikit-learn/client.py rename to e2e/framework-scikit-learn/client.py index b0691e75a79d..24c6617c1289 100644 --- a/e2e/scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -5,7 +5,8 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Load MNIST dataset from 
https://www.openml.org/d/554 (X_train, y_train), (X_test, y_test) = utils.load_mnist() @@ -26,7 +27,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): # type: ignore return utils.get_model_parameters(model) @@ -45,16 +46,14 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:8080", client=FlowerClient().to_client() - ) + start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client()) diff --git a/e2e/scikit-learn/pyproject.toml b/e2e/framework-scikit-learn/pyproject.toml similarity index 100% rename from e2e/scikit-learn/pyproject.toml rename to e2e/framework-scikit-learn/pyproject.toml diff --git a/e2e/scikit-learn/simulation.py b/e2e/framework-scikit-learn/simulation.py similarity index 100% rename from e2e/scikit-learn/simulation.py rename to e2e/framework-scikit-learn/simulation.py diff --git a/e2e/scikit-learn/utils.py b/e2e/framework-scikit-learn/utils.py similarity index 100% rename from e2e/scikit-learn/utils.py rename to e2e/framework-scikit-learn/utils.py diff --git a/e2e/tensorflow/README.md b/e2e/framework-tensorflow/README.md similarity index 100% rename from e2e/tensorflow/README.md rename to e2e/framework-tensorflow/README.md diff --git a/e2e/tensorflow/client.py b/e2e/framework-tensorflow/client.py similarity index 81% rename from e2e/tensorflow/client.py rename to e2e/framework-tensorflow/client.py index 779be0c3746d..351f495a3acb 100644 --- a/e2e/tensorflow/client.py +++ b/e2e/framework-tensorflow/client.py @@ -2,7 +2,8 @@ import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from 
flwr.common import Context SUBSET_SIZE = 1000 @@ -18,7 +19,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -33,16 +34,14 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/tensorflow/pyproject.toml b/e2e/framework-tensorflow/pyproject.toml similarity index 100% rename from e2e/tensorflow/pyproject.toml rename to e2e/framework-tensorflow/pyproject.toml diff --git a/e2e/tabnet/simulation.py b/e2e/framework-tensorflow/simulation.py similarity index 100% rename from e2e/tabnet/simulation.py rename to e2e/framework-tensorflow/simulation.py diff --git a/e2e/tensorflow/simulation_next.py b/e2e/framework-tensorflow/simulation_next.py similarity index 100% rename from e2e/tensorflow/simulation_next.py rename to e2e/framework-tensorflow/simulation_next.py diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index 505340e013a5..0403416cc3b7 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -2,7 +2,8 @@ import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context SUBSET_SIZE = 1000 @@ -33,7 +34,7 @@ def get_model(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -48,17 +49,15 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def 
client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py index abf9cdb5a5c7..c567f33b236b 100644 --- a/e2e/strategies/test.py +++ b/e2e/strategies/test.py @@ -3,8 +3,8 @@ import tensorflow as tf from client import SUBSET_SIZE, FlowerClient, get_model -import flwr as fl -from flwr.common import ndarrays_to_parameters +from flwr.common import Context, ndarrays_to_parameters +from flwr.server import ServerConfig from flwr.server.strategy import ( FaultTolerantFedAvg, FedAdagrad, @@ -15,6 +15,7 @@ FedYogi, QFedAvg, ) +from flwr.simulation import start_simulation STRATEGY_LIST = [ FedMedian, @@ -42,8 +43,7 @@ def get_strat(name): init_model = get_model() -def client_fn(cid): - _ = cid +def client_fn(context: Context): return FlowerClient() @@ -71,10 +71,10 @@ def evaluate(server_round, parameters, config): if start_idx >= OPT_IDX: strat_args["tau"] = 0.01 -hist = fl.simulation.start_simulation( +hist = start_simulation( client_fn=client_fn, num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), + config=ServerConfig(num_rounds=3), strategy=strategy(**strat_args), ) diff --git a/e2e/tabnet/README.md b/e2e/tabnet/README.md deleted file mode 100644 index 258043c3ffa8..000000000000 --- a/e2e/tabnet/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Flower with Tabnet testing - -This directory is used for testing Flower with Tabnet. - -It uses the `FedAvg` strategy. 
diff --git a/e2e/tabnet/client.py b/e2e/tabnet/client.py deleted file mode 100644 index 1a7ecfd68f73..000000000000 --- a/e2e/tabnet/client.py +++ /dev/null @@ -1,95 +0,0 @@ -import os - -import tabnet -import tensorflow as tf -import tensorflow_datasets as tfds - -import flwr as fl - -train_size = 125 -BATCH_SIZE = 50 -col_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"] - - -def transform(ds): - features = tf.unstack(ds["features"]) - labels = ds["label"] - - x = dict(zip(col_names, features)) - y = tf.one_hot(labels, 3) - return x, y - - -def prepare_iris_dataset(): - ds_full = tfds.load(name="iris", split=tfds.Split.TRAIN) - ds_full = ds_full.shuffle(150, seed=0) - - ds_train = ds_full.take(train_size) - ds_train = ds_train.map(transform) - ds_train = ds_train.batch(BATCH_SIZE) - - ds_test = ds_full.skip(train_size) - ds_test = ds_test.map(transform) - ds_test = ds_test.batch(BATCH_SIZE) - - feature_columns = [] - for col_name in col_names: - feature_columns.append(tf.feature_column.numeric_column(col_name)) - - return ds_train, ds_test, feature_columns - - -ds_train, ds_test, feature_columns = prepare_iris_dataset() -# Make TensorFlow log less verbose -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - -# Load TabNet model -model = tabnet.TabNetClassifier( - feature_columns, - num_classes=3, - feature_dim=8, - output_dim=4, - num_decision_steps=4, - relaxation_factor=1.0, - sparsity_coefficient=1e-5, - batch_momentum=0.98, - virtual_batch_size=None, - norm_type="group", - num_groups=1, -) -lr = tf.keras.optimizers.schedules.ExponentialDecay( - 0.01, decay_steps=100, decay_rate=0.9, staircase=False -) -optimizer = tf.keras.optimizers.Adam(lr) -model.compile(optimizer, loss="categorical_crossentropy", metrics=["accuracy"]) - - -# Define Flower client -class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config): - return model.get_weights() - - def fit(self, parameters, config): - model.set_weights(parameters) - model.fit(ds_train, 
epochs=25) - return model.get_weights(), len(ds_train), {} - - def evaluate(self, parameters, config): - model.set_weights(parameters) - loss, accuracy = model.evaluate(ds_test) - return loss, len(ds_train), {"accuracy": accuracy} - - -def client_fn(cid): - return FlowerClient().to_client() - - -app = fl.client.ClientApp( - client_fn=client_fn, -) - -if __name__ == "__main__": - # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) diff --git a/e2e/tabnet/pyproject.toml b/e2e/tabnet/pyproject.toml deleted file mode 100644 index 99379ddb607e..000000000000 --- a/e2e/tabnet/pyproject.toml +++ /dev/null @@ -1,25 +0,0 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[project] -name = "quickstart-tabnet-test" -version = "0.1.0" -description = "Tabnet Federated Learning E2E test with Flower" -authors = [ - { name = "The Flower Authors", email = "hello@flower.ai" }, -] -dependencies = [ - "flwr[simulation] @ {root:parent:parent:uri}", - "tensorflow-cpu>=2.9.1,!=2.11.1; platform_machine == \"x86_64\"", - "tensorflow-macos>=2.9.1,!=2.11.1; sys_platform == \"darwin\" and platform_machine == \"arm64\"", - "tensorflow_datasets==4.9.2", - "tensorflow-io-gcs-filesystem<0.35.0", - "tabnet==0.1.6", -] - -[tool.hatch.build.targets.wheel] -packages = ["."] - -[tool.hatch.metadata] -allow-direct-references = true diff --git a/e2e/tensorflow/simulation.py b/e2e/tensorflow/simulation.py deleted file mode 100644 index bf05a77cf32a..000000000000 --- a/e2e/tensorflow/simulation.py +++ /dev/null @@ -1,14 +0,0 @@ -from client import client_fn - -import flwr as fl - -hist = fl.simulation.start_simulation( - client_fn=client_fn, - num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), -) - -assert ( - hist.losses_distributed[-1][1] == 0 - or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98 -) diff --git a/e2e/test.sh b/e2e/test_legacy.sh similarity index 96% 
rename from e2e/test.sh rename to e2e/test_legacy.sh index 4ea17a4f994b..dc17ca8c6378 100755 --- a/e2e/test.sh +++ b/e2e/test_legacy.sh @@ -2,7 +2,7 @@ set -e case "$1" in - pandas) + framework-pandas) server_file="server.py" ;; bare-https) diff --git a/e2e/test_driver.sh b/e2e/test_superlink.sh similarity index 90% rename from e2e/test_driver.sh rename to e2e/test_superlink.sh index e177863bab78..1bb81cc47ea1 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_superlink.sh @@ -2,7 +2,7 @@ set -e case "$1" in - pandas) + framework-pandas) server_arg="--insecure" client_arg="--insecure" server_dir="./" @@ -70,11 +70,11 @@ timeout 2m flower-superlink $server_arg $db_arg $rest_arg_superlink $server_auth sl_pid=$! sleep 3 -timeout 2m flower-client-app client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_1 & +timeout 2m flower-supernode client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_1 & cl1_pid=$! sleep 3 -timeout 2m flower-client-app client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_2 & +timeout 2m flower-supernode client:app $client_arg $rest_arg_supernode --superlink $server_address $client_auth_2 & cl2_pid=$! sleep 3 diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index c1ba85b95879..ac0737673407 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Advanced Flower Example (PyTorch) This example demonstrates an advanced federated learning setup using Flower with PyTorch. 
This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md index 94707b5cbc98..375c539d13dd 100644 --- a/examples/advanced-tensorflow/README.md +++ b/examples/advanced-tensorflow/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, vision, fds] +dataset: [CIFAR-10] +framework: [tensorflow, Keras] +--- + # Advanced Flower Example (TensorFlow/Keras) This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) and it differs from the quickstart example in the following ways: diff --git a/examples/android-kotlin/README.md b/examples/android-kotlin/README.md index 2d0f704fdc0e..6cadb8e436fe 100644 --- a/examples/android-kotlin/README.md +++ b/examples/android-kotlin/README.md @@ -1,3 +1,9 @@ +--- +tags: [mobile, vision, sdk] +dataset: [CIFAR-10] +framework: [Android, Kotlin, TensorFlowLite] +--- + # Flower Android Client Example with Kotlin and TensorFlow Lite 2022 This example is similar to the Flower Android Example in Java: diff --git a/examples/android/README.md b/examples/android/README.md index f9f2bb93b8dc..83519f15d04d 100644 --- a/examples/android/README.md +++ b/examples/android/README.md @@ -1,3 +1,9 @@ +--- +tags: [mobile, vision, sdk] +dataset: [CIFAR-10] +framework: [Android, Java, TensorFlowLite] +--- + # Flower Android Example (TensorFlowLite) This example demonstrates a federated learning setup with Android clients in a background thread. The training on Android is done on a CIFAR10 dataset using TensorFlow Lite. 
The setup is as follows: diff --git a/examples/app-pytorch/README.md b/examples/app-pytorch/README.md index 14de3c7d632e..5cfae8440ed2 100644 --- a/examples/app-pytorch/README.md +++ b/examples/app-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower App (PyTorch) 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower diff --git a/examples/app-secure-aggregation/README.md b/examples/app-secure-aggregation/README.md index d1ea7bdc893f..8e483fb2f6bd 100644 --- a/examples/app-secure-aggregation/README.md +++ b/examples/app-secure-aggregation/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [] +framework: [numpy] +--- + # Secure aggregation with Flower (the SecAgg+ protocol) 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower diff --git a/examples/custom-metrics/README.md b/examples/custom-metrics/README.md index 317fb6336106..dd6985070cef 100644 --- a/examples/custom-metrics/README.md +++ b/examples/custom-metrics/README.md @@ -1,3 +1,10 @@ +--- +title: Example Flower App with Custom Metrics +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [tensorflow] +--- + # Flower Example using Custom Metrics This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of ready-available `scikit-learn` metrics: accuracy, recall, precision, and f1-score. 
diff --git a/examples/custom-mods/README.md b/examples/custom-mods/README.md index 6b03abcfbfe0..c2007eb323ae 100644 --- a/examples/custom-mods/README.md +++ b/examples/custom-mods/README.md @@ -1,3 +1,9 @@ +--- +tags: [mods, monitoring, app] +dataset: [CIFAR-10] +framework: [wandb, tensorboard] +--- + # Using custom mods 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower @@ -207,7 +213,7 @@ app = fl.client.ClientApp( client_fn=client_fn, mods=[ get_wandb_mod("Custom mods example"), - ], + ], ) ``` diff --git a/examples/doc/source/.gitignore b/examples/doc/source/.gitignore index dd449725e188..73ee14e96f68 100644 --- a/examples/doc/source/.gitignore +++ b/examples/doc/source/.gitignore @@ -1 +1,2 @@ *.md +index.rst diff --git a/examples/embedded-devices/README.md b/examples/embedded-devices/README.md index f1c5931b823a..86f19399932d 100644 --- a/examples/embedded-devices/README.md +++ b/examples/embedded-devices/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10, MNIST] +framework: [torch, tensorflow] +--- + # Federated Learning on Embedded Devices with Flower This example will show you how Flower makes it very easy to run Federated Learning workloads on edge devices. Here we'll be showing how to use NVIDIA Jetson devices and Raspberry Pi as Flower clients. You can run this example using either PyTorch or Tensorflow. The FL workload (i.e. model, dataset and training loop) is mostly borrowed from the [quickstart-pytorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch) and [quickstart-tensorflow](https://github.com/adap/flower/tree/main/examples/quickstart-tensorflow) examples. @@ -65,7 +71,7 @@ If you are working on this tutorial on your laptop or desktop, it can host the F - Install `pip`. In the terminal type: `sudo apt install python3-pip -y` - Now clone this directory. You just need to execute the `git clone` command shown at the top of this README.md on your device. 
- - Install Flower and your ML framework: We have prepared some convenient installation scripts that will install everything you need. You are free to install other versions of these ML frameworks to suit your needs. + - Install Flower and your ML framework of choice: We have prepared some convenient installation scripts that will install everything you need. You are free to install other versions of these ML frameworks to suit your needs. - If you want your clients to use PyTorch: `pip3 install -r requirements_pytorch.txt` - If you want your clients to use TensorFlow: `pip3 install -r requirements_tf.txt` diff --git a/examples/federated-kaplan-meier-fitter/README.md b/examples/federated-kaplan-meier-fitter/README.md index 1569467d6f82..20d4ca4c47af 100644 --- a/examples/federated-kaplan-meier-fitter/README.md +++ b/examples/federated-kaplan-meier-fitter/README.md @@ -1,3 +1,9 @@ +--- +tags: [estimator, medical] +dataset: [Waltons] +framework: [lifelines] +--- + # Flower Example using KaplanMeierFitter This is an introductory example on **federated survival analysis** using [Flower](https://flower.ai/) diff --git a/examples/fl-dp-sa/README.md b/examples/fl-dp-sa/README.md index 47eedb70a2b8..65c8a5b18fa8 100644 --- a/examples/fl-dp-sa/README.md +++ b/examples/fl-dp-sa/README.md @@ -1,4 +1,10 @@ -# fl_dp_sa +--- +tags: [basic, vision, fds] +dataset: [MNIST] +framework: [torch, torchvision] +--- + +# Example of Flower App with DP and SA This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation. Note: This example is designed for a small number of rounds and is intended for demonstration purposes. 
diff --git a/examples/fl-tabular/README.md b/examples/fl-tabular/README.md index 58afd1080b70..ee6dd7d00ef0 100644 --- a/examples/fl-tabular/README.md +++ b/examples/fl-tabular/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, tabular, fds] +dataset: [Adult Census Income] +framework: [scikit-learn, torch] +--- + # Flower Example on Adult Census Income Tabular Dataset This code exemplifies a federated learning setup using the Flower framework on the ["Adult Census Income"](https://huggingface.co/datasets/scikit-learn/adult-census-income) tabular dataset. The "Adult Census Income" dataset contains demographic information such as age, education, occupation, etc., with the target attribute being income level (\<=50K or >50K). The dataset is partitioned into subsets, simulating a federated environment with 5 clients, each holding a distinct portion of the data. Categorical variables are one-hot encoded, and the data is split into training and testing sets. Federated learning is conducted using the FedAvg strategy for 5 rounds. 
diff --git a/examples/flower-authentication/README.md b/examples/flower-authentication/README.md index 589270e621c9..d10780eeae5d 100644 --- a/examples/flower-authentication/README.md +++ b/examples/flower-authentication/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower Authentication with PyTorch 🧪 > 🧪 = This example covers experimental features that might change in future versions of Flower diff --git a/examples/flower-in-30-minutes/README.md b/examples/flower-in-30-minutes/README.md index 5fd9b882413b..faec3d72dae2 100644 --- a/examples/flower-in-30-minutes/README.md +++ b/examples/flower-in-30-minutes/README.md @@ -1,3 +1,9 @@ +--- +tags: [colab, vision, simulation] +dataset: [CIFAR-10] +framework: [torch] +--- + # 30-minute tutorial running Flower simulation with PyTorch This README links to a Jupyter notebook that you can either download and run locally or [![open it in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adap/flower/blob/main/examples/flower-in-30-minutes/tutorial.ipynb). This is a short 30-minute (or less!) tutorial showcasing the basics of Flower federated learning simulations using PyTorch. diff --git a/examples/flower-simulation-step-by-step-pytorch/README.md b/examples/flower-simulation-step-by-step-pytorch/README.md index beb8dd7f6f95..b00afedbe80b 100644 --- a/examples/flower-simulation-step-by-step-pytorch/README.md +++ b/examples/flower-simulation-step-by-step-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, simulation] +dataset: [MNIST] +framework: [torch] +--- + # Flower Simulation Step-by-Step > Since this tutorial (and its video series) was put together, Flower has been updated a few times. As a result, some of the steps to construct the environment (see below) have been updated. Some parts of the code have also been updated. 
Overall, the content of this tutorial and how things work remains the same as in the video tutorials. diff --git a/examples/flower-via-docker-compose/README.md b/examples/flower-via-docker-compose/README.md index 3ef1ac37bcda..3325a731fecf 100644 --- a/examples/flower-via-docker-compose/README.md +++ b/examples/flower-via-docker-compose/README.md @@ -1,3 +1,10 @@ +--- +title: Leveraging Flower and Docker for Device Heterogeneity Management in FL +tags: [deployment, vision, tutorial] +dataset: [CIFAR-10] +framework: [Docker, tensorflow] +--- + # Leveraging Flower and Docker for Device Heterogeneity Management in Federated Learning

diff --git a/examples/ios/README.md b/examples/ios/README.md index 4e17e7a674f3..aef4177dddf7 100644 --- a/examples/ios/README.md +++ b/examples/ios/README.md @@ -1,3 +1,9 @@ +--- +tags: [mobile, vision, sdk] +dataset: [MNIST] +framework: [Swift] +--- + # FLiOS - A Flower SDK for iOS Devices with Example FLiOS is a sample application for testing and benchmarking the Swift implementation of Flower. The default scenario uses the MNIST dataset and the associated digit recognition model. The app includes the Swift package in `./src/swift` and allows extension for other benchmarking scenarios. The app guides the user through the steps of the machine learning process that would be executed in a normal production environment as a background task of the application. The app is therefore aimed at researchers and research institutions to test their hypotheses and perform performance analyses. diff --git a/examples/llm-flowertune/README.md b/examples/llm-flowertune/README.md index 4f98072f8c7f..46076e0b2078 100644 --- a/examples/llm-flowertune/README.md +++ b/examples/llm-flowertune/README.md @@ -1,3 +1,10 @@ +--- +title: Federated LLM Fine-tuning with Flower +tags: [llm, nlp, LLama2] +dataset: [Alpaca-GPT4] +framework: [PEFT, torch] +--- + # LLM FlowerTune: Federated LLM Fine-tuning with Flower Large language models (LLMs), which have been trained on vast amounts of publicly accessible data, have shown remarkable effectiveness in a wide range of areas. diff --git a/examples/opacus/README.md b/examples/opacus/README.md index 6fc0d2ff49a0..aea5d0f689fe 100644 --- a/examples/opacus/README.md +++ b/examples/opacus/README.md @@ -1,3 +1,9 @@ +--- +tags: [dp, security, fds] +dataset: [CIFAR-10] +framework: [opacus, torch] +--- + # Training with Sample-Level Differential Privacy using Opacus Privacy Engine In this example, we demonstrate how to train a model with differential privacy (DP) using Flower. 
We employ PyTorch and integrate the Opacus Privacy Engine to achieve sample-level differential privacy. This setup ensures robust privacy guarantees during the client training phase. The code is adapted from the [PyTorch Quickstart example](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch). diff --git a/examples/pytorch-federated-variational-autoencoder/README.md b/examples/pytorch-federated-variational-autoencoder/README.md index 00af7a6328b2..52f94a16307c 100644 --- a/examples/pytorch-federated-variational-autoencoder/README.md +++ b/examples/pytorch-federated-variational-autoencoder/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower Example for Federated Variational Autoencoder using Pytorch This example demonstrates how a variational autoencoder (VAE) can be trained in a federated way using the Flower framework. diff --git a/examples/pytorch-from-centralized-to-federated/README.md b/examples/pytorch-from-centralized-to-federated/README.md index 06ee89dddcac..1bff7d02f52c 100644 --- a/examples/pytorch-from-centralized-to-federated/README.md +++ b/examples/pytorch-from-centralized-to-federated/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds] +dataset: [CIFAR-10] +framework: [torch] +--- + # PyTorch: From Centralized To Federated This example demonstrates how an already existing centralized PyTorch-based machine learning project can be federated with Flower. diff --git a/examples/quickstart-cpp/README.md b/examples/quickstart-cpp/README.md index d6cbeebe1bc6..61b76ece52b0 100644 --- a/examples/quickstart-cpp/README.md +++ b/examples/quickstart-cpp/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, linear regression, tabular] +dataset: [Synthetic] +framework: [C++] +--- + # Flower Clients in C++ (under development) In this example you will train a linear model on synthetic data using C++ clients. 
diff --git a/examples/quickstart-fastai/README.md b/examples/quickstart-fastai/README.md index 38ef23c95a1e..d1bf97cd4203 100644 --- a/examples/quickstart-fastai/README.md +++ b/examples/quickstart-fastai/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision] +dataset: [MNIST] +framework: [fastai] +--- + # Flower Example using fastai This introductory example to Flower uses [fastai](https://www.fast.ai/), but deep knowledge of fastai is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md index ce7790cd4af5..fa4330040ea7 100644 --- a/examples/quickstart-huggingface/README.md +++ b/examples/quickstart-huggingface/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, llm, nlp, sentiment] +dataset: [IMDB] +framework: [transformers] +--- + # Federated HuggingFace Transformers using Flower and PyTorch This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.ai/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline. diff --git a/examples/quickstart-jax/README.md b/examples/quickstart-jax/README.md index 836adf558d88..b47f3a82e13b 100644 --- a/examples/quickstart-jax/README.md +++ b/examples/quickstart-jax/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, linear regression] +dataset: [Synthetic] +framework: [JAX] +--- + # JAX: From Centralized To Federated This example demonstrates how an already existing centralized JAX-based machine learning project can be federated with Flower. 
diff --git a/examples/quickstart-mlcube/README.md b/examples/quickstart-mlcube/README.md index 8e6fc29b3ad8..f0c6c5664a82 100644 --- a/examples/quickstart-mlcube/README.md +++ b/examples/quickstart-mlcube/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, deployment] +dataset: [MNIST] +framework: [mlcube, tensorflow, Keras] +--- + # Flower Example using TensorFlow/Keras + MLCube This introductory example to Flower uses MLCube together with Keras, but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use-cases with MLCube. Running this example in itself is quite easy. diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md index cca55bcb946a..a4ac44bf8460 100644 --- a/examples/quickstart-mlx/README.md +++ b/examples/quickstart-mlx/README.md @@ -1,3 +1,10 @@ +--- +title: Simple Flower Example using MLX +tags: [quickstart, vision] +dataset: [MNIST] +framework: [MLX] +--- + # Flower Example using MLX This introductory example to Flower uses [MLX](https://ml-explore.github.io/mlx/build/html/index.html), but deep knowledge of MLX is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. diff --git a/examples/quickstart-monai/README.md b/examples/quickstart-monai/README.md index 4a9afef4f86a..dc31f03e4b1b 100644 --- a/examples/quickstart-monai/README.md +++ b/examples/quickstart-monai/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, medical, vision] +dataset: [MedNIST] +framework: [MONAI] +--- + # Flower Example using MONAI This introductory example to Flower uses MONAI, but deep knowledge of MONAI is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. 
diff --git a/examples/quickstart-pandas/README.md b/examples/quickstart-pandas/README.md index dd69f3ead3cb..0b4b3a6ac78a 100644 --- a/examples/quickstart-pandas/README.md +++ b/examples/quickstart-pandas/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, tabular, federated analytics] +dataset: [Iris] +framework: [pandas] +--- + # Flower Example using Pandas This introductory example to Flower uses Pandas, but deep knowledge of Pandas is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to diff --git a/examples/quickstart-pytorch-lightning/README.md b/examples/quickstart-pytorch-lightning/README.md index fb29c7e9e9ea..04eb911818fc 100644 --- a/examples/quickstart-pytorch-lightning/README.md +++ b/examples/quickstart-pytorch-lightning/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, fds] +dataset: [MNIST] +framework: [lightning] +--- + # Flower Example using PyTorch Lightning This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch Lightning is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 93d6a593f362..8eace1ea6845 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, fds] +dataset: [CIFAR-10] +framework: [torch, torchvision] +--- + # Flower Example using PyTorch This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. 
Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. diff --git a/examples/quickstart-sklearn-tabular/README.md b/examples/quickstart-sklearn-tabular/README.md index a975a9392800..b0b4cd1b84c0 100644 --- a/examples/quickstart-sklearn-tabular/README.md +++ b/examples/quickstart-sklearn-tabular/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, tabular, fds] +dataset: [Iris] +framework: [scikit-learn] +--- + # Flower Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on diff --git a/examples/quickstart-tabnet/README.md b/examples/quickstart-tabnet/README.md index 19a139f83064..e8be55eaacef 100644 --- a/examples/quickstart-tabnet/README.md +++ b/examples/quickstart-tabnet/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, tabular] +dataset: [Iris] +framework: [tabnet] +--- + # Flower TabNet Example using TensorFlow This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understanding how to adapt Flower to your use-cases. You can learn more about TabNet from [paper](https://arxiv.org/abs/1908.07442) and its implementation using TensorFlow at [this repository](https://github.com/titu1994/tf-TabNet). Note also that the basis of this example using federated learning is the example from the repository above. 
diff --git a/examples/quickstart-tensorflow/README.md b/examples/quickstart-tensorflow/README.md index ae1fe19834a3..386f8bbd96f0 100644 --- a/examples/quickstart-tensorflow/README.md +++ b/examples/quickstart-tensorflow/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, vision, fds] +dataset: [CIFAR-10] +framework: [tensorflow] +--- + # Flower Example using TensorFlow/Keras This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. diff --git a/examples/simulation-pytorch/README.md b/examples/simulation-pytorch/README.md index 93f9e1acbac7..2dbfbc849ab7 100644 --- a/examples/simulation-pytorch/README.md +++ b/examples/simulation-pytorch/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds, simulation] +dataset: [MNIST] +framework: [torch, torchvision] +--- + # Flower Simulation example using PyTorch This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. diff --git a/examples/simulation-tensorflow/README.md b/examples/simulation-tensorflow/README.md index 917d7b34c7af..047cb4379659 100644 --- a/examples/simulation-tensorflow/README.md +++ b/examples/simulation-tensorflow/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds, simulation] +dataset: [MNIST] +framework: [tensorflow, Keras] +--- + # Flower Simulation example using TensorFlow/Keras This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. 
Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. diff --git a/examples/sklearn-logreg-mnist/README.md b/examples/sklearn-logreg-mnist/README.md index 12b1a5e3bc1a..b117c5452086 100644 --- a/examples/sklearn-logreg-mnist/README.md +++ b/examples/sklearn-logreg-mnist/README.md @@ -1,4 +1,10 @@ -# Flower Example using scikit-learn +--- +tags: [basic, vision, logistic regression, fds] +dataset: [MNIST] +framework: [scikit-learn] +--- + +# Flower Logistic Regression Example using scikit-learn This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MNIST dataset. diff --git a/examples/tensorflow-privacy/README.md b/examples/tensorflow-privacy/README.md index a1f1be00f6b0..8156f92f60c9 100644 --- a/examples/tensorflow-privacy/README.md +++ b/examples/tensorflow-privacy/README.md @@ -1,3 +1,9 @@ +--- +tags: [basic, vision, fds, privacy, dp] +dataset: [MNIST] +framework: [tensorflow] +--- + # Training with Sample-Level Differential Privacy using TensorFlow-Privacy Engine In this example, we demonstrate how to train a model with sample-level differential privacy (DP) using Flower. We employ TensorFlow and integrate the tensorflow-privacy Engine to achieve sample-level differential privacy. This setup ensures robust privacy guarantees during the client training phase. 
diff --git a/examples/vertical-fl/README.md b/examples/vertical-fl/README.md index ba8228a059f9..ab5d2210d8d5 100644 --- a/examples/vertical-fl/README.md +++ b/examples/vertical-fl/README.md @@ -1,3 +1,10 @@ +--- +title: Vertical FL Flower Example +tags: [vertical, tabular, advanced] +dataset: [Titanic] +framework: [torch, pandas, scikit-learn] +--- + # Vertical Federated Learning example This example will showcase how you can perform Vertical Federated Learning using diff --git a/examples/vit-finetune/README.md b/examples/vit-finetune/README.md index ac1652acf02d..957c0eda0b68 100644 --- a/examples/vit-finetune/README.md +++ b/examples/vit-finetune/README.md @@ -1,3 +1,10 @@ +--- +title: Federated finetuning of a ViT +tags: [finetuning, vision, fds] +dataset: [Oxford Flower-102] +framework: [torch, torchvision] +--- + # Federated finetuning of a ViT This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple we'll be finetuning it to [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) datasset, creating 20 partitions using [Flower Datasets](https://flower.ai/docs/datasets/). We'll be finetuning just the exit `head` of the ViT, this means that the training is not that costly and each client requires just ~1GB of VRAM (for a batch size of 32 images). 
diff --git a/examples/whisper-federated-finetuning/README.md b/examples/whisper-federated-finetuning/README.md index ddebe51247b2..cfd0db842bae 100644 --- a/examples/whisper-federated-finetuning/README.md +++ b/examples/whisper-federated-finetuning/README.md @@ -1,3 +1,9 @@ +--- +tags: [finetuning, speech, transformers] +dataset: [SpeechCommands] +framework: [transformers, whisper] +--- + # On-device Federated Finetuning for Speech Classification This example demonstrates how to, from a pre-trained [Whisper](https://openai.com/research/whisper) model, finetune it for the downstream task of keyword spotting. We'll be implementing a federated downstream finetuning pipeline using Flower involving a total of 100 clients. As for the downstream dataset, we'll be using the [Google Speech Commands](https://huggingface.co/datasets/speech_commands) dataset for keyword spotting. We'll take the encoder part of the [Whisper-tiny](https://huggingface.co/openai/whisper-tiny) model, freeze its parameters, and learn a lightweight classification (\<800K parameters !!) head to correctly classify a spoken word. diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index dc6d7e3872d6..62fcba2bb06d 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -1,3 +1,9 @@ +--- +tags: [advanced, classification, tabular] +dataset: [HIGGS] +framework: [xgboost] +--- + # Flower Example using XGBoost (Comprehensive) This example demonstrates a comprehensive federated learning setup using Flower with XGBoost. 
diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index 713b6eab8bac..fa3e9d0dc6fb 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -1,3 +1,9 @@ +--- +tags: [quickstart, classification, tabular] +dataset: [HIGGS] +framework: [xgboost] +--- + # Flower Example using XGBoost This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package. diff --git a/pyproject.toml b/pyproject.toml index 5daf007471aa..c5ab0e5edcee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ flower-simulation = "flwr.simulation.run_simulation:run_simulation_from_cli" python = "^3.8" # Mandatory dependencies numpy = "^1.21.0" -grpcio = "^1.60.0" +grpcio = "^1.60.0,!=1.64.2,!=1.65.0" protobuf = "^4.25.2" cryptography = "^42.0.4" pycryptodome = "^3.18.0" @@ -131,12 +131,7 @@ licensecheck = "==2024" pre-commit = "==3.5.0" [tool.isort] -line_length = 88 -indent = " " -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true +profile = "black" known_first_party = ["flwr", "flwr_tool"] [tool.black] diff --git a/src/docker/base/alpine/Dockerfile b/src/docker/base/alpine/Dockerfile index 04864b525e2e..9e58d82e3bda 100644 --- a/src/docker/base/alpine/Dockerfile +++ b/src/docker/base/alpine/Dockerfile @@ -54,9 +54,11 @@ FROM python:${PYTHON_VERSION}-${DISTRO}${DISTRO_VERSION} as base # required by the grpc package RUN apk add --no-cache \ libstdc++ \ + ca-certificates \ # add non-root user && adduser \ --no-create-home \ + --home /app \ --disabled-password \ --gecos "" \ --uid 49999 app \ diff --git a/src/docker/base/ubuntu/Dockerfile b/src/docker/base/ubuntu/Dockerfile index 4aeddc3f8d8d..960ed07edf96 100644 --- a/src/docker/base/ubuntu/Dockerfile +++ b/src/docker/base/ubuntu/Dockerfile @@ -48,29 +48,7 @@ RUN git clone https://github.com/pyenv/pyenv.git \ RUN LATEST=$(pyenv latest -k ${PYTHON_VERSION}) \ && 
python-build "${LATEST}" /usr/local/bin/python -FROM $DISTRO:$DISTRO_VERSION as base - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get update \ - && apt-get -y --no-install-recommends install \ - libsqlite3-0 \ - && rm -rf /var/lib/apt/lists/* - -COPY --from=python /usr/local/bin/python /usr/local/bin/python - -ENV PATH=/usr/local/bin/python/bin:$PATH \ - # Send stdout and stderr stream directly to the terminal. Ensures that no - # output is retained in a buffer if the application crashes. - PYTHONUNBUFFERED=1 \ - # Typically, bytecode is created on the first invocation to speed up following invocation. - # However, in Docker we only make a single invocation (when we start the container). - # Therefore, we can disable bytecode writing. - PYTHONDONTWRITEBYTECODE=1 \ - # Ensure that python encoding is always UTF-8. - PYTHONIOENCODING=UTF-8 \ - LANG=C.UTF-8 \ - LC_ALL=C.UTF-8 +ENV PATH=/usr/local/bin/python/bin:$PATH # Use a virtual environment to ensure that Python packages are installed in the same location # regardless of whether the subsequent image build is run with the app or the root user @@ -86,16 +64,43 @@ RUN pip install -U --no-cache-dir \ setuptools==${SETUPTOOLS_VERSION} \ ${FLWR_PACKAGE}==${FLWR_VERSION} -# add non-root user -RUN adduser \ +FROM $DISTRO:$DISTRO_VERSION as base + +COPY --from=python /usr/local/bin/python /usr/local/bin/python + +ENV DEBIAN_FRONTEND=noninteractive \ + PATH=/usr/local/bin/python/bin:$PATH + +RUN apt-get update \ + && apt-get -y --no-install-recommends install \ + libsqlite3-0 \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* \ + # add non-root user + && adduser \ --no-create-home \ + --home /app \ --disabled-password \ --gecos "" \ --uid 49999 app \ && mkdir -p /app \ - && chown -R app:app /python \ && chown -R app:app /app +COPY --from=python --chown=app:app /python/venv /python/venv + +ENV PATH=/python/venv/bin:$PATH \ + # Send stdout and stderr stream directly to the terminal. 
Ensures that no + # output is retained in a buffer if the application crashes. + PYTHONUNBUFFERED=1 \ + # Typically, bytecode is created on the first invocation to speed up following invocation. + # However, in Docker we only make a single invocation (when we start the container). + # Therefore, we can disable bytecode writing. + PYTHONDONTWRITEBYTECODE=1 \ + # Ensure that python encoding is always UTF-8. + PYTHONIOENCODING=UTF-8 \ + LANG=C.UTF-8 \ + LC_ALL=C.UTF-8 + WORKDIR /app USER app ENV HOME=/app diff --git a/src/docker/superexec/Dockerfile b/src/docker/superexec/Dockerfile new file mode 100644 index 000000000000..9e4cc722921e --- /dev/null +++ b/src/docker/superexec/Dockerfile @@ -0,0 +1,20 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +ARG BASE_REPOSITORY=flwr/base +ARG BASE_IMAGE +FROM $BASE_REPOSITORY:$BASE_IMAGE + +ENTRYPOINT ["flower-superexec"] diff --git a/src/docker/supernode/Dockerfile b/src/docker/supernode/Dockerfile index 8dce1c389a5b..8b78577b1201 100644 --- a/src/docker/supernode/Dockerfile +++ b/src/docker/supernode/Dockerfile @@ -17,4 +17,4 @@ ARG BASE_REPOSITORY=flwr/base ARG BASE_IMAGE FROM $BASE_REPOSITORY:$BASE_IMAGE -ENTRYPOINT ["flower-client-app"] +ENTRYPOINT ["flower-supernode"] diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index edbd5d91bb5b..77dc52b3258b 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -42,6 +42,7 @@ service Driver { message CreateRunRequest { string fab_id = 1; string fab_version = 2; + map<string, string> override_config = 3; } message CreateRunResponse { sint64 run_id = 1; } diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 8e5f53b02ca8..d0d8dfcbb273 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -25,7 +25,10 @@ service Exec { rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {} } -message StartRunRequest { bytes fab_file = 1; } +message StartRunRequest { + bytes fab_file = 1; + map<string, string> override_config = 2; +} message StartRunResponse { sint64 run_id = 1; } message StreamLogsRequest { sint64 run_id = 1; } message StreamLogsResponse { string log_output = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 76a7fd91532f..e41748381cab 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -21,6 +21,7 @@ message Run { sint64 run_id = 1; string fab_id = 2; string fab_version = 3; + map<string, string> override_config = 4; } message GetRunRequest { sint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/py/flwr/cli/config_utils.py 
b/src/py/flwr/cli/config_utils.py index d06a1d6dba96..33bf12e34b04 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -108,6 +108,14 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]: return load_from_string(toml_file.read()) +def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None: + for key, value in config_dict.items(): + if isinstance(value, dict): + _validate_run_config(config_dict[key], errors) + elif not isinstance(value, str): + errors.append(f"Config value of key {key} is not of type `str`.") + + # pylint: disable=too-many-branches def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: """Validate pyproject.toml fields.""" @@ -133,6 +141,8 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] else: if "publisher" not in config["flower"]: errors.append('Property "publisher" missing in [flower]') + if "config" in config["flower"]: + _validate_run_config(config["flower"]["config"], errors) if "components" not in config["flower"]: errors.append("Missing [flower.components] section") else: diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 9367cf6c9ffb..a0a2dc98556d 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -264,9 +264,11 @@ def new( bold=True, ) ) + + _add = " huggingface-cli login\n" if framework_str == "flowertune" else "" print( typer.style( - f" cd {package_name}\n" + " pip install -e .\n flwr run\n", + f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n", fg=typer.colors.BRIGHT_CYAN, bold=True, ) diff --git a/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl index 314da2120c53..56bac8543c50 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl @@ -1,6 +1,7 @@ """$project_name: A Flower / HuggingFace 
def client_fn(context: Context):
    """Create the Flower client for the partition described by ``context``.

    Reads ``partition-id`` and ``num-partitions`` from ``context.node_config``,
    converts both to ``int`` and passes them to ``load_data``.
    """
    # Load model and data
    net = AutoModelForSequenceClassification.from_pretrained(
        CHECKPOINT, num_labels=2
    ).to(DEVICE)

    partition_id = int(context.node_config["partition-id"])
    # The original key literal was missing its closing quote, which made the
    # generated client file a syntax error.
    num_partitions = int(context.node_config["num-partitions"])
    trainloader, valloader = load_data(partition_id, num_partitions)

    # Return Client instance
    return FlowerClient(net, trainloader, valloader).to_client()
def client_fn(context: Context):
    """Build the client for the data partition named in the node config.

    ``partition-id`` and ``num-partitions`` are read from
    ``context.node_config`` and converted to ``int`` before being handed to
    ``load_data``.
    """
    node_cfg = context.node_config
    data = load_data(
        int(node_cfg["partition-id"]),
        int(node_cfg["num-partitions"]),
    )

    # Return Client instance
    return FlowerClient(data).to_client()
def client_fn(context: Context):
    """Create the Flower client for the partition described by ``context``.

    Reads ``partition-id`` and ``num-partitions`` from ``context.node_config``,
    converts both to ``int`` and passes them to ``load_data``.
    """
    # Load model and data
    net = load_model()

    partition_id = int(context.node_config["partition-id"])
    # Take the partition count from the node config, as the other framework
    # templates do, instead of hard-coding 2.
    num_partitions = int(context.node_config["num-partitions"])
    x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)

    # Return Client instance
    return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
def server_fn(context: Context):
    """Return the components the ServerApp runs with.

    Uses a default-constructed ``FedAvg`` strategy and a ``ServerConfig`` with
    ``num_rounds=3``. ``context`` is accepted but not read.
    """
    return ServerAppComponents(
        strategy=FedAvg(),
        config=ServerConfig(num_rounds=3),
    )
def server_fn(context: Context):
    """Assemble the strategy and run configuration for the ServerApp.

    The ``FedAvg`` strategy receives the module-level ``parameters`` as its
    ``initial_parameters``; the run is configured for three rounds.
    ``context`` is accepted but not read.
    """
    fedavg = FedAvg(
        initial_parameters=parameters,
        min_available_clients=2,
        fraction_evaluate=1.0,
        fraction_fit=1.0,
    )
    return ServerAppComponents(strategy=fedavg, config=ServerConfig(num_rounds=3))
def server_fn(context: Context):
    """Assemble the strategy and run configuration for the ServerApp.

    The ``FedAvg`` strategy receives the module-level ``parameters`` as its
    ``initial_parameters``; the run is configured for three rounds.
    ``context`` is accepted but not read.
    """
    # Define strategy (the original repeated the assignment target:
    # `strategy = strategy = FedAvg(...)`)
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=parameters,
    )
    config = ServerConfig(num_rounds=3)

    return ServerAppComponents(strategy=strategy, config=config)
8e89add66835..eb43acfce976 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl @@ -16,9 +16,9 @@ DEVICE = torch.device("cpu") CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(partition_id, num_clients): +def load_data(partition_id: int, num_partitions: int): """Load IMDB data (training and eval)""" - fds = FederatedDataset(dataset="imdb", partitioners={"train": num_clients}) + fds = FederatedDataset(dataset="imdb", partitioners={"train": num_partitions}) partition = fds.load_partition(partition_id) # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) diff --git a/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl index bcd4dde93310..88053b0cd590 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl @@ -43,8 +43,8 @@ def batch_iterate(batch_size, X, y): yield X[ids], y[ids] -def load_data(partition_id, num_clients): - fds = FederatedDataset(dataset="mnist", partitioners={"train": num_clients}) +def load_data(partition_id: int, num_partitions: int): + fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions}) partition = fds.load_partition(partition_id) partition_splits = partition.train_test_split(test_size=0.2, seed=42) diff --git a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl index b30c65a285b5..d5971ffb6ce5 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl @@ -34,7 +34,7 @@ class Net(nn.Module): return self.fc3(x) -def load_data(partition_id, num_partitions): +def load_data(partition_id: int, num_partitions: int): """Load partition CIFAR10 data.""" fds = 
FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}) partition = fds.load_partition(partition_id) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index f5882bd14ab8..4ee2368f5794 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -18,13 +18,14 @@ from enum import Enum from logging import DEBUG from pathlib import Path -from typing import Optional +from typing import Dict, Optional import typer from typing_extensions import Annotated from flwr.cli import config_utils from flwr.cli.build import build +from flwr.common.config import parse_config_args from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log @@ -58,15 +59,20 @@ def run( Optional[Path], typer.Option(help="Path of the Flower project to run"), ] = None, + config_overrides: Annotated[ + Optional[str], + typer.Option( + "--config", + "-c", + help="Override configuration key-value pairs", + ), + ] = None, ) -> None: """Run Flower project.""" - if use_superexec: - _start_superexec_run(directory) - return - typer.secho("Loading project configuration... 
", fg=typer.colors.BLUE) - config, errors, warnings = config_utils.load_and_validate() + pyproject_path = directory / "pyproject.toml" if directory else None + config, errors, warnings = config_utils.load_and_validate(path=pyproject_path) if config is None: typer.secho( @@ -88,6 +94,12 @@ def run( typer.secho("Success", fg=typer.colors.GREEN) + if use_superexec: + _start_superexec_run( + parse_config_args(config_overrides, separator=","), directory + ) + return + server_app_ref = config["flower"]["components"]["serverapp"] client_app_ref = config["flower"]["components"]["clientapp"] @@ -115,7 +127,9 @@ def run( ) -def _start_superexec_run(directory: Optional[Path]) -> None: +def _start_superexec_run( + override_config: Dict[str, str], directory: Optional[Path] +) -> None: def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" log(DEBUG, channel_connectivity) @@ -132,6 +146,9 @@ def on_channel_state_change(channel_connectivity: str) -> None: fab_path = build(directory) - req = StartRunRequest(fab_file=Path(fab_path).read_bytes()) + req = StartRunRequest( + fab_file=Path(fab_path).read_bytes(), + override_config=override_config, + ) res = stub.StartRun(req) typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN) diff --git a/src/py/flwr/client/__init__.py b/src/py/flwr/client/__init__.py index 58fd94448586..218f2fe20d62 100644 --- a/src/py/flwr/client/__init__.py +++ b/src/py/flwr/client/__init__.py @@ -23,11 +23,13 @@ from .supernode import run_client_app as run_client_app from .supernode import run_supernode as run_supernode from .typing import ClientFn as ClientFn +from .typing import ClientFnExt as ClientFnExt __all__ = [ "Client", "ClientApp", "ClientFn", + "ClientFnExt", "NumPyClient", "mod", "run_client_app", diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 1226a0d7bc21..348ef8910dd3 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,7 +18,8 @@ 
import sys import time from dataclasses import dataclass -from logging import DEBUG, ERROR, INFO, WARN +from logging import ERROR, INFO, WARN +from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union from cryptography.hazmat.primitives.asymmetric import ec @@ -26,8 +27,8 @@ from flwr.client.client import Client from flwr.client.client_app import ClientApp, LoadClientAppError -from flwr.client.typing import ClientFn -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event +from flwr.client.typing import ClientFnExt +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -41,6 +42,7 @@ from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential +from flwr.common.typing import Run from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection @@ -51,7 +53,7 @@ def _check_actionable_client( - client: Optional[Client], client_fn: Optional[ClientFn] + client: Optional[Client], client_fn: Optional[ClientFnExt] ) -> None: if client_fn is None and client is None: raise ValueError( @@ -72,7 +74,7 @@ def _check_actionable_client( def start_client( *, server_address: str, - client_fn: Optional[ClientFn] = None, + client_fn: Optional[ClientFnExt] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, @@ -92,7 +94,7 @@ def start_client( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. - client_fn : Optional[ClientFn] + client_fn : Optional[ClientFnExt] A callable that instantiates a Client. 
(default: None) client : Optional[flwr.client.Client] An implementation of the abstract base @@ -136,8 +138,8 @@ class `flwr.client.Client` (default: None) Starting an SSL-enabled gRPC client using system certificates: - >>> def client_fn(cid: str): - >>> return FlowerClient() + >>> def client_fn(context: Context): + >>> return FlowerClient().to_client() >>> >>> start_client( >>> server_address=localhost:8080, @@ -158,6 +160,7 @@ class `flwr.client.Client` (default: None) event(EventType.START_CLIENT_ENTER) _start_client_internal( server_address=server_address, + node_config={}, load_client_app_fn=None, client_fn=client_fn, client=client, @@ -179,8 +182,9 @@ class `flwr.client.Client` (default: None) def _start_client_internal( *, server_address: str, + node_config: Dict[str, str], load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None, - client_fn: Optional[ClientFn] = None, + client_fn: Optional[ClientFnExt] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, @@ -191,6 +195,7 @@ def _start_client_internal( ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, + flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -200,9 +205,11 @@ def _start_client_internal( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. + node_config: Dict[str, str] + The configuration of the node. load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None) A function that can be used to load a `ClientApp` instance. - client_fn : Optional[ClientFn] + client_fn : Optional[ClientFnExt] A callable that instantiates a Client. 
(default: None) client : Optional[flwr.client.Client] An implementation of the abstract base @@ -234,6 +241,8 @@ class `flwr.client.Client` (default: None) The maximum duration before the client stops trying to connect to the server in case of connection error. If set to None, there is no limit to the total time. + flwr_dir: Optional[Path] (default: None) + The fully resolved path containing installed Flower Apps. """ if insecure is None: insecure = root_certificates is None @@ -244,7 +253,7 @@ class `flwr.client.Client` (default: None) if client_fn is None: # Wrap `Client` instance in `client_fn` def single_client_factory( - cid: str, # pylint: disable=unused-argument + context: Context, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy raise ValueError( @@ -285,7 +294,7 @@ def _on_backoff(retry_state: RetryState) -> None: log(WARN, "Connection attempt failed, retrying...") else: log( - DEBUG, + WARN, "Connection attempt failed, retrying in %.2f seconds", retry_state.actual_wait, ) @@ -293,7 +302,7 @@ def _on_backoff(retry_state: RetryState) -> None: retry_invoker = RetryInvoker( wait_gen_factory=exponential, recoverable_exceptions=connection_error_type, - max_tries=max_retries, + max_tries=max_retries + 1 if max_retries is not None else None, max_time=max_wait_time, on_giveup=lambda retry_state: ( log( @@ -309,9 +318,10 @@ def _on_backoff(retry_state: RetryState) -> None: on_backoff=_on_backoff, ) - node_state = NodeState() - # run_id -> (fab_id, fab_version) - run_info: Dict[int, Tuple[str, str]] = {} + # NodeState gets initialized when the first connection is established + node_state: Optional[NodeState] = None + + runs: Dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -325,9 +335,31 @@ def _on_backoff(retry_state: RetryState) -> None: ) as conn: receive, send, create_node, delete_node, get_run = conn - # Register node - if create_node is not None: - create_node() # pylint: 
disable=not-callable + # Register node when connecting the first time + if node_state is None: + if create_node is None: + if transport not in ["grpc-bidi", None]: + raise NotImplementedError( + "All transports except `grpc-bidi` require " + "an implementation for `create_node()`.'" + ) + # gRPC-bidi doesn't have the concept of node_id, + # so we set it to -1 + node_state = NodeState( + node_id=-1, + node_config={}, + ) + else: + # Call create_node fn to register node + node_id: Optional[int] = ( # pylint: disable=assignment-from-none + create_node() + ) # pylint: disable=not-callable + if node_id is None: + raise ValueError("Node registration failed") + node_state = NodeState( + node_id=node_id, + node_config=node_config, + ) app_state_tracker.register_signal_handler() while not app_state_tracker.interrupt: @@ -361,15 +393,17 @@ def _on_backoff(retry_state: RetryState) -> None: # Get run info run_id = message.metadata.run_id - if run_id not in run_info: + if run_id not in runs: if get_run is not None: - run_info[run_id] = get_run(run_id) + runs[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - run_info[run_id] = ("", "") + runs[run_id] = Run(run_id, "", "", {}) # Register context for this run - node_state.register_context(run_id=run_id) + node_state.register_context( + run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir + ) # Retrieve context for this run context = node_state.retrieve_context(run_id=run_id) @@ -383,7 +417,10 @@ def _on_backoff(retry_state: RetryState) -> None: # Handle app loading and task message try: # Load ClientApp instance - client_app: ClientApp = load_client_app_fn(*run_info[run_id]) + run: Run = runs[run_id] + client_app: ClientApp = load_client_app_fn( + run.fab_id, run.fab_version + ) # Execute ClientApp reply_message = client_app(message=message, context=context) @@ -566,9 +603,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Tuple[ Callable[[], Optional[Message]], 
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
    """Validate the signature of ``client_fn`` and adapt the legacy form.

    Parameters
    ----------
    client_fn : ClientFnExt
        Callable supplied by the user when constructing a ``ClientApp``.

    Returns
    -------
    ClientFnExt
        ``client_fn`` unchanged when it already takes a single non-legacy
        parameter. Otherwise a wrapper that accepts a ``Context`` and calls
        the legacy ``client_fn`` with a string id.

    Raises
    ------
    ValueError
        Raised via ``_alert_erroneous_client_fn`` when ``client_fn`` does not
        take exactly one parameter.
    """
    client_fn_args = inspect.signature(client_fn).parameters

    # Check the arity before touching the first parameter: indexing first
    # would raise IndexError for a zero-argument callable instead of the
    # explanatory ValueError.
    if len(client_fn_args) != 1:
        _alert_erroneous_client_fn()

    first_arg = next(iter(client_fn_args))
    first_arg_type = client_fn_args[first_arg].annotation

    if first_arg_type is str or first_arg == "cid":
        # Warn that the previous signature for `client_fn` seems to be used
        warn_deprecated_feature(
            "`client_fn` now expects a signature `def client_fn(context: Context)`. "
            "The provided `client_fn` has signature: "
            f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
            " `from flwr.common import Context`"
        )

        # Wrap the deprecated client_fn inside a function with the expected
        # signature
        def adaptor_fn(context: Context) -> Client:
            # If partition-id is defined, pass it. Otherwise pass node_id,
            # which should always be defined during Context init.
            cid = context.node_config.get("partition-id", context.node_id)
            return client_fn(str(cid))  # type: ignore

        return adaptor_fn

    return client_fn
Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], - Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server via GrpcAdapter. diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 8c049861c672..a6417106d51b 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -38,6 +38,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -71,9 +72,9 @@ def grpc_connection( # pylint: disable=R0913, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], - Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Establish a gRPC connection to a gRPC server. 
diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 34dc0e417383..e573df6854bc 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, @@ -78,9 +79,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], - Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -175,7 +176,7 @@ def ping() -> None: if not ping_stop_event.is_set(): ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" # Call FleetAPI create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) @@ -188,6 +189,7 @@ def create_node() -> None: nonlocal node, ping_thread node = cast(Node, create_node_response.node) ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" @@ -266,7 +268,7 @@ def send(message: Message) -> None: # Cleanup metadata = None - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Call FleetAPI get_run_request = GetRunRequest(run_id=run_id) get_run_response: GetRunResponse = retry_invoker.invoke( @@ -275,7 +277,12 @@ def get_run(run_id: int) -> Tuple[str, str]: ) # Return fab_id and fab_version - return get_run_response.run.fab_id, get_run_response.run.fab_version + return 
Run( + run_id, + get_run_response.run.fab_id, + get_run_response.run.fab_version, + dict(get_run_response.run.override_config.items()), + ) try: # Yield methods diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 68326852970f..1ab84eb01468 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -14,7 +14,6 @@ # ============================================================================== """Client-side message handler.""" - from logging import WARN from typing import Optional, Tuple, cast @@ -25,7 +24,7 @@ maybe_call_get_properties, ) from flwr.client.numpy_client import NumPyClient -from flwr.client.typing import ClientFn +from flwr.client.typing import ClientFnExt from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.recordset_compat import ( @@ -90,10 +89,10 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: def handle_legacy_message_from_msgtype( - client_fn: ClientFn, message: Message, context: Context + client_fn: ClientFnExt, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most mod.""" - client = client_fn(str(message.metadata.partition_id)) + client = client_fn(context) # Check if NumPyClient is returend if isinstance(client, NumPyClient): diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 40907942513d..557d61ffb32a 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -22,7 +22,7 @@ from typing import List from flwr.client import Client -from flwr.client.typing import ClientFn +from flwr.client.typing import ClientFnExt from flwr.common import ( DEFAULT_TTL, 
Code, @@ -113,8 +113,8 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: ) -def _get_client_fn(client: Client) -> ClientFn: - def client_fn(cid: str) -> Client: # pylint: disable=unused-argument +def _get_client_fn(client: Client) -> ClientFnExt: + def client_fn(contex: Context) -> Client: # pylint: disable=unused-argument return client return client_fn @@ -143,7 +143,7 @@ def test_client_without_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet()), + context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert @@ -207,7 +207,7 @@ def test_client_with_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet()), + context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 36844a2983a1..2832576fb4fc 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -73,7 +73,12 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: def _make_ctxt() -> Context: cfg = ConfigsRecord(SecAggPlusState().to_dict()) - return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg})) + return Context( + node_id=123, + node_config={}, + state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}), + run_config={}, + ) def _make_set_state_fn( diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index 035e41639b10..a5bbd0a0bb4d 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -104,7 +104,7 @@ def test_multiple_mods(self) -> None: state = RecordSet() 
state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) - context = Context(state=state) + context = Context(node_id=0, node_config={}, state=state, run_config={}) message = _get_dummy_flower_message() # Execute @@ -129,7 +129,7 @@ def test_filter(self) -> None: # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - context = Context(state=RecordSet()) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) message = _get_dummy_flower_message() def filter_mod( diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 71681b783419..393ca4564a35 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -15,27 +15,58 @@ """Node state.""" -from typing import Any, Dict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional from flwr.common import Context, RecordSet +from flwr.common.config import get_fused_config +from flwr.common.typing import Run + + +@dataclass() +class RunInfo: + """Contains the Context and initial run_config of a Run.""" + + context: Context + initial_run_config: Dict[str, str] class NodeState: """State of a node where client nodes execute runs.""" - def __init__(self) -> None: - self._meta: Dict[str, Any] = {} # holds metadata about the node - self.run_contexts: Dict[int, Context] = {} + def __init__( + self, + node_id: int, + node_config: Dict[str, str], + ) -> None: + self.node_id = node_id + self.node_config = node_config + self.run_infos: Dict[int, RunInfo] = {} - def register_context(self, run_id: int) -> None: + def register_context( + self, + run_id: int, + run: Optional[Run] = None, + flwr_dir: Optional[Path] = None, + ) -> None: """Register new run context for this node.""" - if run_id not in self.run_contexts: - self.run_contexts[run_id] = Context(state=RecordSet()) + if run_id not in self.run_infos: + initial_run_config = get_fused_config(run, flwr_dir) if run else {} + 
self.run_infos[run_id] = RunInfo( + initial_run_config=initial_run_config, + context=Context( + node_id=self.node_id, + node_config=self.node_config, + state=RecordSet(), + run_config=initial_run_config.copy(), + ), + ) def retrieve_context(self, run_id: int) -> Context: """Get run context given a run_id.""" - if run_id in self.run_contexts: - return self.run_contexts[run_id] + if run_id in self.run_infos: + return self.run_infos[run_id].context raise RuntimeError( f"Context for run_id={run_id} doesn't exist." @@ -45,4 +76,9 @@ def retrieve_context(self, run_id: int) -> Context: def update_context(self, run_id: int, context: Context) -> None: """Update run context.""" - self.run_contexts[run_id] = context + if context.run_config != self.run_infos[run_id].initial_run_config: + raise ValueError( + "The `run_config` field of the `Context` object cannot be " + f"modified (run_id: {run_id})." + ) + self.run_infos[run_id].context = context diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index 193f52661579..26ac4fea6855 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None: expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} # NodeState - node_state = NodeState() + node_state = NodeState(node_id=0, node_config={}) for task in tasks: run_id = task.run_id @@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None: node_state.update_context(run_id=run_id, context=updated_state) # Verify values - for run_id, context in node_state.run_contexts.items(): + for run_id, run_info in node_state.run_infos.items(): assert ( - context.state.configs_records["counter"]["count"] == expected_values[run_id] + run_info.context.state.configs_records["counter"]["count"] + == expected_values[run_id] ) diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 
db5bd7eb6770..3e81969d898c 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -89,9 +90,9 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], - Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -236,19 +237,20 @@ def ping() -> None: if not ping_stop_event.is_set(): ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) # Send the request res = _request(req, CreateNodeResponse, PATH_CREATE_NODE) if res is None: - return + return None # Remember the node and the ping-loop thread nonlocal node, ping_thread node = res.node ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" @@ -344,16 +346,21 @@ def send(message: Message) -> None: res.results, # pylint: disable=no-member ) - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Construct the request req = GetRunRequest(run_id=run_id) # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) if res is None: - return "", "" + return Run(run_id, "", "", {}) - return res.run.fab_id, res.run.fab_version + return Run( + run_id, + res.run.fab_id, + res.run.fab_version, + dict(res.run.override_config.items()), + ) try: # 
Yield methods diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index c9a16edeaf15..d61b986bc7af 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -29,7 +29,12 @@ from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import EventType, event -from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir +from flwr.common.config import ( + get_flwr_dir, + get_project_config, + get_project_dir, + parse_config_args, +) from flwr.common.constant import ( TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, @@ -67,6 +72,8 @@ def run_supernode() -> None: authentication_keys=authentication_keys, max_retries=args.max_retries, max_wait_time=args.max_wait_time, + node_config=parse_config_args(args.node_config), + flwr_dir=get_flwr_dir(args.flwr_dir), ) # Graceful shutdown @@ -91,6 +98,7 @@ def run_client_app() -> None: _start_client_internal( server_address=args.superlink, + node_config=parse_config_args(args.node_config), load_client_app_fn=load_fn, transport=args.transport, root_certificates=root_certificates, @@ -177,7 +185,7 @@ def _get_load_client_app_fn( else: flwr_dir = Path(args.flwr_dir).absolute() - sys.path.insert(0, str(flwr_dir.absolute())) + inserted_path = None default_app_ref: str = getattr(args, "client-app") @@ -187,6 +195,11 @@ def _get_load_client_app_fn( "Flower SuperNode will load and validate ClientApp `%s`", getattr(args, "client-app"), ) + # Insert sys.path + dir_path = Path(args.dir).absolute() + sys.path.insert(0, str(dir_path)) + inserted_path = str(dir_path) + valid, error_msg = validate(default_app_ref) if not valid and error_msg: raise LoadClientAppError(error_msg) from None @@ -195,7 +208,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: # If multi-app feature is disabled if not multi_app: # Get sys path to be inserted - sys_path = Path(args.dir).absolute() + dir_path = Path(args.dir).absolute() # Set app 
reference client_app_ref = default_app_ref @@ -208,7 +221,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.") # Get sys path to be inserted - sys_path = Path(args.dir).absolute() + dir_path = Path(args.dir).absolute() # Set app reference client_app_ref = default_app_ref @@ -221,13 +234,21 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: raise LoadClientAppError("Failed to load ClientApp") from e # Get sys path to be inserted - sys_path = Path(project_dir).absolute() + dir_path = Path(project_dir).absolute() # Set app reference client_app_ref = config["flower"]["components"]["clientapp"] # Set sys.path - sys.path.insert(0, str(sys_path)) + nonlocal inserted_path + if inserted_path != str(dir_path): + # Remove the previously inserted path + if inserted_path is not None: + sys.path.remove(inserted_path) + # Insert the new path + sys.path.insert(0, str(dir_path)) + + inserted_path = str(dir_path) # Load ClientApp log( @@ -235,7 +256,7 @@ def _load(fab_id: str, fab_version: str) -> ClientApp: "Loading ClientApp `%s`", client_app_ref, ) - client_app = load_app(client_app_ref, LoadClientAppError, sys_path) + client_app = load_app(client_app_ref, LoadClientAppError, dir_path) if not isinstance(client_app, ClientApp): raise LoadClientAppError( @@ -344,8 +365,8 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: "--max-retries", type=int, default=None, - help="The maximum number of times the client will try to connect to the" - "server before giving up in case of a connection error. By default," + help="The maximum number of times the client will try to reconnect to the" + "SuperLink before giving up in case of a connection error. 
By default," "it is set to None, meaning there is no limit to the number of tries.", ) parser.add_argument( @@ -353,7 +374,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: type=float, default=None, help="The maximum duration before the client stops trying to" - "connect to the server in case of connection error. By default, it" + "connect to the SuperLink in case of connection error. By default, it" "is set to None, meaning there is no limit to the total time.", ) parser.add_argument( @@ -373,6 +394,13 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: type=str, help="The SuperNode's public key (as a path str) to enable authentication.", ) + parser.add_argument( + "--node-config", + type=str, + help="A comma separated list of key/value pairs (separated by `=`) to " + "configure the SuperNode. " + "E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'", + ) def _try_setup_client_authentication( diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 956ac7a15c05..9faed4bc7283 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -23,6 +23,7 @@ # Compatibility ClientFn = Callable[[str], Client] +ClientFnExt = Callable[[Context], Client] ClientAppCallable = Callable[[Message, Context], Message] Mod = Callable[[Message, Context, ClientAppCallable], Message] diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 20de00a6fba9..54d74353e4ed 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -16,12 +16,13 @@ import os from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import tomli from flwr.cli.config_utils import validate_fields from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME +from flwr.common.typing import Run def get_flwr_dir(provided_path: Optional[str] = None) -> Path: @@ -30,7 +31,7 @@ def 
get_flwr_dir(provided_path: Optional[str] = None) -> Path: return Path( os.getenv( FLWR_HOME, - f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr", + Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr", ) ) return Path(provided_path).absolute() @@ -71,3 +72,75 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: ) return config + + +def _fuse_dicts( + main_dict: Dict[str, str], override_dict: Dict[str, str] +) -> Dict[str, str]: + fused_dict = main_dict.copy() + + for key, value in override_dict.items(): + if key in main_dict: + fused_dict[key] = value + + return fused_dict + + +def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: + """Merge the overrides from a `Run` with the config from a FAB. + + Get the config using the fab_id and the fab_version, remove the nesting by adding + the nested keys as prefixes separated by dots, and fuse it with the override dict. + """ + if not run.fab_id or not run.fab_version: + return {} + + project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir) + + default_config = get_project_config(project_dir)["flower"].get("config", {}) + flat_default_config = flatten_dict(default_config) + + return _fuse_dicts(flat_default_config, run.override_config) + + +def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]: + """Flatten dict by joining nested keys with a given separator.""" + items: List[Tuple[str, str]] = [] + separator: str = "." 
+ for k, v in raw_dict.items(): + new_key = f"{parent_key}{separator}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, parent_key=new_key).items()) + elif isinstance(v, str): + items.append((new_key, v)) + else: + raise ValueError( + f"The value for key {k} needs to be a `str` or a `dict`.", + ) + return dict(items) + + +def parse_config_args( + config: Optional[str], + separator: str = ",", +) -> Dict[str, str]: + """Parse separator separated list of key-value pairs separated by '='.""" + overrides: Dict[str, str] = {} + + if config is None: + return overrides + + overrides_list = config.split(separator) + if ( + len(overrides_list) == 1 + and "=" not in overrides_list + and overrides_list[0].endswith(".toml") + ): + with Path(overrides_list[0]).open("rb") as config_file: + overrides = flatten_dict(tomli.load(config_file)) + else: + for kv_pair in overrides_list: + key, value = kv_pair.split("=") + overrides[key] = value + + return overrides diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py new file mode 100644 index 000000000000..fe429bab9cb5 --- /dev/null +++ b/src/py/flwr/common/config_test.py @@ -0,0 +1,230 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test util functions handling Flower config.""" + +import os +import textwrap +from pathlib import Path +from unittest.mock import patch + +import pytest + +from .config import ( + _fuse_dicts, + flatten_dict, + get_flwr_dir, + get_project_config, + get_project_dir, + parse_config_args, +) + +# Mock constants +FAB_CONFIG_FILE = "pyproject.toml" + + +def test_get_flwr_dir_with_provided_path() -> None: + """Test get_flwr_dir with a provided valid path.""" + provided_path = "." + assert get_flwr_dir(provided_path) == Path(provided_path).absolute() + + +def test_get_flwr_dir_without_provided_path() -> None: + """Test get_flwr_dir without a provided path, using default home directory.""" + with patch.dict(os.environ, {"HOME": "/home/user"}): + assert get_flwr_dir() == Path("/home/user/.flwr") + + +def test_get_flwr_dir_with_flwr_home() -> None: + """Test get_flwr_dir with FLWR_HOME environment variable set.""" + with patch.dict(os.environ, {"FLWR_HOME": "/custom/flwr/home"}): + assert get_flwr_dir() == Path("/custom/flwr/home") + + +def test_get_flwr_dir_with_xdg_data_home() -> None: + """Test get_flwr_dir with FLWR_HOME environment variable set.""" + with patch.dict(os.environ, {"XDG_DATA_HOME": "/custom/data/home"}): + assert get_flwr_dir() == Path("/custom/data/home/.flwr") + + +def test_get_project_dir_invalid_fab_id() -> None: + """Test get_project_dir with an invalid fab_id.""" + with pytest.raises(ValueError): + get_project_dir("invalid_fab_id", "1.0.0") + + +def test_get_project_dir_valid() -> None: + """Test get_project_dir with an valid fab_id and version.""" + app_path = get_project_dir("app_name/user", "1.0.0", flwr_dir=".") + assert app_path == Path("apps") / "app_name" / "user" / "1.0.0" + + +def test_get_project_config_file_not_found() -> None: + """Test get_project_config when the configuration file is not found.""" + with pytest.raises(FileNotFoundError): + 
get_project_config("/invalid/dir") + + +def test_get_fused_config_valid(tmp_path: Path) -> None: + """Test get_project_config when the configuration file is not found.""" + pyproject_toml_content = """ + [build-system] + requires = ["hatchling"] + build-backend = "hatchling.build" + + [project] + name = "fedgpt" + version = "1.0.0" + description = "" + license = {text = "Apache License (2.0)"} + dependencies = [ + "flwr[simulation]>=1.9.0,<2.0", + "numpy>=1.21.0", + ] + + [flower] + publisher = "flwrlabs" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.config] + num_server_rounds = "10" + momentum = "0.1" + lr = "0.01" + serverapp.test = "key" + + [flower.config.clientapp] + test = "key" + """ + overrides = { + "num_server_rounds": "5", + "lr": "0.2", + "serverapp.test": "overriden", + } + expected_config = { + "num_server_rounds": "5", + "momentum": "0.1", + "lr": "0.2", + "serverapp.test": "overriden", + "clientapp.test": "key", + } + # Current directory + origin = Path.cwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open(FAB_CONFIG_FILE, "w", encoding="utf-8") as f: + f.write(textwrap.dedent(pyproject_toml_content)) + + # Execute + default_config = get_project_config(tmp_path)["flower"].get("config", {}) + + config = _fuse_dicts(flatten_dict(default_config), overrides) + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_get_project_config_file_valid(tmp_path: Path) -> None: + """Test get_project_config when the configuration file is not found.""" + pyproject_toml_content = """ + [build-system] + requires = ["hatchling"] + build-backend = "hatchling.build" + + [project] + name = "fedgpt" + version = "1.0.0" + description = "" + license = {text = "Apache License (2.0)"} + dependencies = [ + "flwr[simulation]>=1.9.0,<2.0", + "numpy>=1.21.0", + ] + + [flower] + publisher = "flwrlabs" + + [flower.components] + serverapp = 
"fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.config] + num_server_rounds = "10" + momentum = "0.1" + lr = "0.01" + """ + expected_config = { + "build-system": {"build-backend": "hatchling.build", "requires": ["hatchling"]}, + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": {"text": "Apache License (2.0)"}, + "dependencies": ["flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0"], + }, + "flower": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "config": { + "num_server_rounds": "10", + "momentum": "0.1", + "lr": "0.01", + }, + }, + } + # Current directory + origin = Path.cwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open(FAB_CONFIG_FILE, "w", encoding="utf-8") as f: + f.write(textwrap.dedent(pyproject_toml_content)) + + # Execute + config = get_project_config(tmp_path) + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_flatten_dict() -> None: + """Test flatten_dict with a nested dictionary.""" + raw_dict = {"a": {"b": {"c": "d"}}, "e": "f"} + expected = {"a.b.c": "d", "e": "f"} + assert flatten_dict(raw_dict) == expected + + +def test_parse_config_args_none() -> None: + """Test parse_config_args with None as input.""" + assert not parse_config_args(None) + + +def test_parse_config_args_overrides() -> None: + """Test parse_config_args with key-value pairs.""" + assert parse_config_args("key1=value1,key2=value2") == { + "key1": "value1", + "key2": "value2", + } diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index ce29b3edb30e..72256a62add7 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -46,6 +46,9 @@ PING_RANDOM_RANGE = (-0.1, 0.1) PING_MAX_INTERVAL = 1e300 +# IDs +RUN_ID_NUM_BYTES = 8 +NODE_ID_NUM_BYTES = 8 GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" 
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit" @@ -54,6 +57,9 @@ FAB_CONFIG_FILE = "pyproject.toml" FLWR_HOME = "FLWR_HOME" +# Constants entries in Node config for Simulation +PARTITION_ID_KEY = "partition-id" +NUM_PARTITIONS_KEY = "num-partitions" GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit" diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index b6349307d150..4da52ba44481 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -16,16 +16,22 @@ from dataclasses import dataclass +from typing import Dict from .record import RecordSet @dataclass class Context: - """State of your run. + """Context of your run. Parameters ---------- + node_id : int + The ID that identifies the node. + node_config : Dict[str, str] + A config (key/value mapping) unique to the node and independent of the + `run_config`. This config persists across all runs this node participates in. state : RecordSet Holds records added by the entity in a given run and that will stay local. This means that the data it holds will never leave the system it's running from. @@ -33,6 +39,25 @@ class Context: executing mods. It can also be used as a memory to access at different points during the lifecycle of this entity (e.g. across multiple rounds) + run_config : Dict[str, str] + A config (key/value mapping) held by the entity in a given run and that will + stay local. It can be used at any point during the lifecycle of this entity + (e.g. 
across multiple rounds) """ + node_id: int + node_config: Dict[str, str] state: RecordSet + run_config: Dict[str, str] + + def __init__( # pylint: disable=too-many-arguments + self, + node_id: int, + node_config: Dict[str, str], + state: RecordSet, + run_config: Dict[str, str], + ) -> None: + self.node_id = node_id + self.node_config = node_config + self.state = state + self.run_config = run_config diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py index 7225b0663ae7..2077f9beaca0 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -197,6 +197,44 @@ def warn_deprecated_feature(name: str) -> None: ) +def warn_deprecated_feature_with_example( + deprecation_message: str, example_message: str, code_example: str +) -> None: + """Warn if a feature is deprecated and show code example.""" + log( + WARN, + """DEPRECATED FEATURE: %s + + Check the following `FEATURE UPDATE` warning message for the preferred + new mechanism to use this feature in Flower. + """, + deprecation_message, + ) + log( + WARN, + """FEATURE UPDATE: %s + ------------------------------------------------------------ + %s + ------------------------------------------------------------ + """, + example_message, + code_example, + ) + + +def warn_unsupported_feature(name: str) -> None: + """Warn the user when they use an unsupported feature.""" + log( + WARN, + """UNSUPPORTED FEATURE: %s + + This is an unsupported feature. It will be removed + entirely in future versions of Flower. + """, + name, + ) + + def set_logger_propagation( child_logger: logging.Logger, value: bool = True ) -> logging.Logger: diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 7f7a0e4dd995..4138fc95a591 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -48,10 +48,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes message_type : str A string that encodes the action to be executed on the receiving end. 
- partition_id : Optional[int] - An identifier that can be used when loading a particular - data partition for a ClientApp. Making use of this identifier - is more relevant when conducting simulations. """ def __init__( # pylint: disable=too-many-arguments @@ -64,7 +60,6 @@ def __init__( # pylint: disable=too-many-arguments group_id: str, ttl: float, message_type: str, - partition_id: int | None = None, ) -> None: var_dict = { "_run_id": run_id, @@ -75,7 +70,6 @@ def __init__( # pylint: disable=too-many-arguments "_group_id": group_id, "_ttl": ttl, "_message_type": message_type, - "_partition_id": partition_id, } self.__dict__.update(var_dict) @@ -149,16 +143,6 @@ def message_type(self, value: str) -> None: """Set message_type.""" self.__dict__["_message_type"] = value - @property - def partition_id(self) -> int | None: - """An identifier telling which data partition a ClientApp should use.""" - return cast(int, self.__dict__["_partition_id"]) - - @partition_id.setter - def partition_id(self, value: int) -> None: - """Set partition_id.""" - self.__dict__["_partition_id"] = value - def __repr__(self) -> str: """Return a string representation of this instance.""" view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()]) @@ -398,5 +382,4 @@ def _create_reply_metadata(msg: Message, ttl: float) -> Metadata: group_id=msg.metadata.group_id, ttl=ttl, message_type=msg.metadata.message_type, - partition_id=msg.metadata.partition_id, ) diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py index daee57896903..c6142cb18256 100644 --- a/src/py/flwr/common/message_test.py +++ b/src/py/flwr/common/message_test.py @@ -174,7 +174,6 @@ def test_create_reply( "group_id": "group_xyz", "ttl": 10.0, "message_type": "request", - "partition_id": None, }, ), (Error, {"code": 1, "reason": "reason_098"}), diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index f51830955679..04d2cf5bbf7f 100644 --- 
a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -194,3 +194,4 @@ class Run: run_id: int fab_id: str fab_version: str + override_config: Dict[str, str] diff --git a/src/py/flwr/proto/common_pb2.py b/src/py/flwr/proto/common_pb2.py new file mode 100644 index 000000000000..8a6430137f05 --- /dev/null +++ b/src/py/flwr/proto/common_pb2.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/common.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/common.proto\x12\nflwr.protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.common_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/common_pb2.pyi b/src/py/flwr/proto/common_pb2.pyi new file mode 100644 index 000000000000..e08fa11c2caa --- /dev/null +++ b/src/py/flwr/proto/common_pb2.pyi @@ -0,0 +1,7 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import google.protobuf.descriptor + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor diff --git a/src/py/flwr/proto/common_pb2_grpc.py b/src/py/flwr/proto/common_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/common_pb2_grpc.pyi b/src/py/flwr/proto/common_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/common_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index a2458b445563..07975937328d 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -17,29 +17,33 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 
\x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"\xb9\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 
\x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST']._serialized_start=107 - _globals['_CREATERUNREQUEST']._serialized_end=162 - _globals['_CREATERUNRESPONSE']._serialized_start=164 - _globals['_CREATERUNRESPONSE']._serialized_end=199 - _globals['_GETNODESREQUEST']._serialized_start=201 - _globals['_GETNODESREQUEST']._serialized_end=234 - _globals['_GETNODESRESPONSE']._serialized_start=236 - _globals['_GETNODESRESPONSE']._serialized_end=287 - _globals['_PUSHTASKINSREQUEST']._serialized_start=289 - _globals['_PUSHTASKINSREQUEST']._serialized_end=353 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=355 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=394 - _globals['_PULLTASKRESREQUEST']._serialized_start=396 - _globals['_PULLTASKRESREQUEST']._serialized_end=466 - _globals['_PULLTASKRESRESPONSE']._serialized_start=468 - _globals['_PULLTASKRESRESPONSE']._serialized_end=533 - _globals['_DRIVER']._serialized_start=536 - _globals['_DRIVER']._serialized_end=924 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_CREATERUNREQUEST']._serialized_start=108 + 
_globals['_CREATERUNREQUEST']._serialized_end=293 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=240 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=293 + _globals['_CREATERUNRESPONSE']._serialized_start=295 + _globals['_CREATERUNRESPONSE']._serialized_end=330 + _globals['_GETNODESREQUEST']._serialized_start=332 + _globals['_GETNODESREQUEST']._serialized_end=365 + _globals['_GETNODESRESPONSE']._serialized_start=367 + _globals['_GETNODESRESPONSE']._serialized_end=418 + _globals['_PUSHTASKINSREQUEST']._serialized_start=420 + _globals['_PUSHTASKINSREQUEST']._serialized_end=484 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=486 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=525 + _globals['_PULLTASKRESREQUEST']._serialized_start=527 + _globals['_PULLTASKRESREQUEST']._serialized_end=597 + _globals['_PULLTASKRESRESPONSE']._serialized_start=599 + _globals['_PULLTASKRESRESPONSE']._serialized_end=664 + _globals['_DRIVER']._serialized_start=667 + _globals['_DRIVER']._serialized_end=1055 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 2d8d11fb59a3..95d4c9785ff1 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -16,16 +16,33 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class CreateRunRequest(google.protobuf.message.Message): """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... 
+ FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... global___CreateRunRequest = CreateRunRequest class CreateRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 7b037a9454c0..4aee0f4a882f 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -14,21 +14,25 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"#\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"\xa4\x01\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 
\x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.exec_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_STARTRUNREQUEST']._serialized_start=37 - _globals['_STARTRUNREQUEST']._serialized_end=72 - _globals['_STARTRUNRESPONSE']._serialized_start=74 - _globals['_STARTRUNRESPONSE']._serialized_end=108 - _globals['_STREAMLOGSREQUEST']._serialized_start=110 - _globals['_STREAMLOGSREQUEST']._serialized_end=145 - _globals['_STREAMLOGSRESPONSE']._serialized_start=147 - _globals['_STREAMLOGSRESPONSE']._serialized_end=187 - _globals['_EXEC']._serialized_start=190 - _globals['_EXEC']._serialized_end=350 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_STARTRUNREQUEST']._serialized_start=38 + _globals['_STARTRUNREQUEST']._serialized_end=202 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=149 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=202 + _globals['_STARTRUNRESPONSE']._serialized_start=204 + _globals['_STARTRUNRESPONSE']._serialized_end=238 + _globals['_STREAMLOGSREQUEST']._serialized_start=240 + _globals['_STREAMLOGSREQUEST']._serialized_end=275 + 
_globals['_STREAMLOGSRESPONSE']._serialized_start=277 + _globals['_STREAMLOGSRESPONSE']._serialized_end=317 + _globals['_EXEC']._serialized_start=320 + _globals['_EXEC']._serialized_end=480 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 466812808da8..8065fc1de1b4 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,13 +13,30 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class StartRunRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + FAB_FILE_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_file: builtins.bytes + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, fab_file: builtins.bytes = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","override_config",b"override_config"]) -> None: ... 
global___StartRunRequest = StartRunRequest class StartRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 13f06e7169aa..d6531201f647 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -14,17 +14,21 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\"\xaf\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_RUN']._serialized_start=36 - _globals['_RUN']._serialized_end=94 - _globals['_GETRUNREQUEST']._serialized_start=96 - _globals['_GETRUNREQUEST']._serialized_end=127 - _globals['_GETRUNRESPONSE']._serialized_start=129 - _globals['_GETRUNRESPONSE']._serialized_end=175 + _globals['_RUN_OVERRIDECONFIGENTRY']._options = None + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + 
_globals['_RUN']._serialized_start=37 + _globals['_RUN']._serialized_end=212 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=159 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=212 + _globals['_GETRUNREQUEST']._serialized_start=214 + _globals['_GETRUNREQUEST']._serialized_end=245 + _globals['_GETRUNRESPONSE']._serialized_start=247 + _globals['_GETRUNRESPONSE']._serialized_end=293 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 401d27855a41..3c58c04c1734 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.message import typing import typing_extensions @@ -12,19 +13,36 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class Run(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: typing.Text + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + RUN_ID_FIELD_NUMBER: builtins.int FAB_ID_FIELD_NUMBER: builtins.int FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int run_id: builtins.int fab_id: typing.Text fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... def __init__(self, *, run_id: builtins.int = ..., fab_id: typing.Text = ..., fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., ) -> None: ... 
- def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run class GetRunRequest(google.protobuf.message.Message): diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index 546ce263e2d5..896b46298327 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -28,6 +28,7 @@ from .server import Server as Server from .server_app import ServerApp as ServerApp from .server_config import ServerConfig as ServerConfig +from .serverapp_components import ServerAppComponents as ServerAppComponents __all__ = [ "ClientManager", @@ -36,6 +37,7 @@ "LegacyContext", "Server", "ServerApp", + "ServerAppComponents", "ServerConfig", "SimpleClientManager", "run_server_app", diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py index 0b00c98bb16d..ee09d79012dc 100644 --- a/src/py/flwr/server/compat/legacy_context.py +++ b/src/py/flwr/server/compat/legacy_context.py @@ -52,4 +52,4 @@ def __init__( self.strategy = strategy self.client_manager = client_manager self.history = History() - super().__init__(state) + super().__init__(node_id=0, node_config={}, state=state, run_config={}) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index e614df659e3f..84da5882eb73 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -16,8 +16,8 @@ import time import warnings -from logging import DEBUG, ERROR, WARNING -from typing import Iterable, List, Optional, Tuple, cast +from logging import DEBUG, WARNING +from typing import Iterable, List, Optional, cast import grpc @@ -27,8 +27,6 @@ from flwr.common.serde import 
message_from_taskres, message_to_taskins from flwr.common.typing import Run from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunRequest, - CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -53,167 +51,103 @@ """ -class GrpcDriverStub: - """`GrpcDriverStub` provides access to the gRPC Driver API/service. +class GrpcDriver(Driver): + """`GrpcDriver` provides an interface to the Driver API. Parameters ---------- - driver_service_address : Optional[str] - The IPv4 or IPv6 address of the Driver API server. - Defaults to `"[::]:9091"`. + run_id : int + The identifier of the run. + driver_service_address : str (default: "[::]:9091") + The address (URL, IPv6, IPv4) of the SuperLink Driver API service. root_certificates : Optional[bytes] (default: None) The PEM-encoded root certificates as a byte string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. """ - def __init__( + def __init__( # pylint: disable=too-many-arguments self, + run_id: int, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, root_certificates: Optional[bytes] = None, ) -> None: - self.driver_service_address = driver_service_address - self.root_certificates = root_certificates - self.channel: Optional[grpc.Channel] = None - self.stub: Optional[DriverStub] = None + self._run_id = run_id + self._addr = driver_service_address + self._cert = root_certificates + self._run: Optional[Run] = None + self._grpc_stub: Optional[DriverStub] = None + self._channel: Optional[grpc.Channel] = None + self.node = Node(node_id=0, anonymous=True) + + @property + def _is_connected(self) -> bool: + """Check if connected to the Driver API server.""" + return self._channel is not None - def is_connected(self) -> bool: - """Return True if connected to the Driver API server, otherwise False.""" - return self.channel is not None + def _connect(self) -> None: + """Connect to the Driver API. 
- def connect(self) -> None: - """Connect to the Driver API.""" + This will not call GetRun. + """ event(EventType.DRIVER_CONNECT) - if self.channel is not None or self.stub is not None: + if self._is_connected: log(WARNING, "Already connected") return - self.channel = create_channel( - server_address=self.driver_service_address, - insecure=(self.root_certificates is None), - root_certificates=self.root_certificates, + self._channel = create_channel( + server_address=self._addr, + insecure=(self._cert is None), + root_certificates=self._cert, ) - self.stub = DriverStub(self.channel) - log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) + self._grpc_stub = DriverStub(self._channel) + log(DEBUG, "[Driver] Connected to %s", self._addr) - def disconnect(self) -> None: + def _disconnect(self) -> None: """Disconnect from the Driver API.""" event(EventType.DRIVER_DISCONNECT) - if self.channel is None or self.stub is None: + if not self._is_connected: log(DEBUG, "Already disconnected") return - channel = self.channel - self.channel = None - self.stub = None + channel: grpc.Channel = self._channel + self._channel = None + self._grpc_stub = None channel.close() log(DEBUG, "[Driver] Disconnected") - def create_run(self, req: CreateRunRequest) -> CreateRunResponse: - """Request for run ID.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call Driver API - res: CreateRunResponse = self.stub.CreateRun(request=req) - return res - - def get_run(self, req: GetRunRequest) -> GetRunResponse: - """Get run information.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call gRPC Driver API - res: GetRunResponse = self.stub.GetRun(request=req) - return res - - def get_nodes(self, req: GetNodesRequest) -> 
GetNodesResponse: - """Get client IDs.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call gRPC Driver API - res: GetNodesResponse = self.stub.GetNodes(request=req) - return res - - def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: - """Schedule tasks.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call gRPC Driver API - res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) - return res - - def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: - """Get task results.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverStub` instance not connected") - - # Call Driver API - res: PullTaskResResponse = self.stub.PullTaskRes(request=req) - return res - - -class GrpcDriver(Driver): - """`Driver` class provides an interface to the Driver API. - - Parameters - ---------- - run_id : int - The identifier of the run. - stub : Optional[GrpcDriverStub] (default: None) - The ``GrpcDriverStub`` instance used to communicate with the SuperLink. - If None, an instance connected to "[::]:9091" will be created. 
- """ - - def __init__( # pylint: disable=too-many-arguments - self, - run_id: int, - stub: Optional[GrpcDriverStub] = None, - ) -> None: - self._run_id = run_id - self._run: Optional[Run] = None - self.stub = stub if stub is not None else GrpcDriverStub() - self.node = Node(node_id=0, anonymous=True) + def _init_run(self) -> None: + # Check if is initialized + if self._run is not None: + return + # Get the run info + req = GetRunRequest(run_id=self._run_id) + res: GetRunResponse = self._stub.GetRun(req) + if not res.HasField("run"): + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._run = Run( + run_id=res.run.run_id, + fab_id=res.run.fab_id, + fab_version=res.run.fab_version, + override_config=dict(res.run.override_config.items()), + ) @property def run(self) -> Run: """Run information.""" - self._get_stub_and_run_id() - return Run(**vars(cast(Run, self._run))) + self._init_run() + return Run(**vars(self._run)) - def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: - # Check if is initialized - if self._run is None: - # Connect - if not self.stub.is_connected(): - self.stub.connect() - # Get the run info - req = GetRunRequest(run_id=self._run_id) - res = self.stub.get_run(req) - if not res.HasField("run"): - raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._run = Run( - run_id=res.run.run_id, - fab_id=res.run.fab_id, - fab_version=res.run.fab_version, - ) - - return self.stub, self._run.run_id + @property + def _stub(self) -> DriverStub: + """Driver stub.""" + if not self._is_connected: + self._connect() + return cast(DriverStub, self._grpc_stub) def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( - message.metadata.run_id == cast(Run, self._run).run_id + # Assume self._run being initialized + message.metadata.run_id == self._run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == 
"" @@ -234,7 +168,7 @@ def create_message( # pylint: disable=too-many-arguments This method constructs a new `Message` with given content and metadata. The `run_id` and `src_node_id` will be set automatically. """ - _, run_id = self._get_stub_and_run_id() + self._init_run() if ttl: warnings.warn( "A custom TTL was set, but note that the SuperLink does not enforce " @@ -245,7 +179,7 @@ def create_message( # pylint: disable=too-many-arguments ttl_ = DEFAULT_TTL if ttl is None else ttl metadata = Metadata( - run_id=run_id, + run_id=self._run_id, message_id="", # Will be set by the server src_node_id=self.node.node_id, dst_node_id=dst_node_id, @@ -258,9 +192,11 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - stub, run_id = self._get_stub_and_run_id() + self._init_run() # Call GrpcDriverStub method - res = stub.get_nodes(GetNodesRequest(run_id=run_id)) + res: GetNodesResponse = self._stub.GetNodes( + GetNodesRequest(run_id=self._run_id) + ) return [node.node_id for node in res.nodes] def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: @@ -269,7 +205,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: This method takes an iterable of messages and sends each message to the node specified in `dst_node_id`. 
""" - stub, _ = self._get_stub_and_run_id() + self._init_run() # Construct TaskIns task_ins_list: List[TaskIns] = [] for msg in messages: @@ -280,7 +216,9 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: # Add to list task_ins_list.append(taskins) # Call GrpcDriverStub method - res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) + res: PushTaskInsResponse = self._stub.PushTaskIns( + PushTaskInsRequest(task_ins_list=task_ins_list) + ) return list(res.task_ids) def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: @@ -289,9 +227,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: This method is used to collect messages from the SuperLink that correspond to a set of given message IDs. """ - stub, _ = self._get_stub_and_run_id() + self._init_run() # Pull TaskRes - res = stub.pull_task_res( + res: PullTaskResResponse = self._stub.PullTaskRes( PullTaskResRequest(node=self.node, task_ids=message_ids) ) # Convert TaskRes to Message @@ -331,7 +269,7 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" # Check if `connect` was called before - if not self.stub.is_connected(): + if not self._is_connected: return # Disconnect - self.stub.disconnect() + self._disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 72efc5f8b2c6..fdf3c676190d 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -41,10 +41,13 @@ def setUp(self) -> None: mock_response = Mock( run=Run(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0") ) - self.mock_grpc_driver_stub = Mock() - self.mock_grpc_driver_stub.get_run.return_value = mock_response - self.mock_grpc_driver_stub.HasField.return_value = True - self.driver = GrpcDriver(run_id=61016, stub=self.mock_grpc_driver_stub) + self.mock_stub = Mock() + self.mock_channel = Mock() + 
self.mock_stub.GetRun.return_value = mock_response + mock_response.HasField.return_value = True + self.driver = GrpcDriver(run_id=61016) + self.driver._grpc_stub = self.mock_stub # pylint: disable=protected-access + self.driver._channel = self.mock_channel # pylint: disable=protected-access def test_init_grpc_driver(self) -> None: """Test GrpcDriverStub initialization.""" @@ -52,21 +55,21 @@ def test_init_grpc_driver(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") - self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() def test_get_nodes(self) -> None: """Test retrieval of nodes.""" # Prepare mock_response = Mock() mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] - self.mock_grpc_driver_stub.get_nodes.return_value = mock_response + self.mock_stub.GetNodes.return_value = mock_response # Execute node_ids = self.driver.get_node_ids() - args, kwargs = self.mock_grpc_driver_stub.get_nodes.call_args + args, kwargs = self.mock_stub.GetNodes.call_args # Assert - self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) @@ -77,7 +80,7 @@ def test_push_messages_valid(self) -> None: """Test pushing valid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -85,10 +88,10 @@ def test_push_messages_valid(self) -> None: # Execute msg_ids = self.driver.push_messages(msgs) - args, kwargs = self.mock_grpc_driver_stub.push_task_ins.call_args + args, kwargs = self.mock_stub.PushTaskIns.call_args # Assert - 
self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PushTaskInsRequest) @@ -100,7 +103,7 @@ def test_push_messages_invalid(self) -> None: """Test pushing invalid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -124,16 +127,16 @@ def test_pull_messages_with_given_message_ids(self) -> None: ), TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), ] - self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response + self.mock_stub.PullTaskRes.return_value = mock_response msg_ids = ["id1", "id2", "id3"] # Execute msgs = self.driver.pull_messages(msg_ids) reply_tos = {msg.metadata.reply_to_message for msg in msgs} - args, kwargs = self.mock_grpc_driver_stub.pull_task_res.call_args + args, kwargs = self.mock_stub.PullTaskRes.call_args # Assert - self.mock_grpc_driver_stub.get_run.assert_called_once() + self.mock_stub.GetRun.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullTaskResRequest) @@ -144,14 +147,14 @@ def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" # Prepare mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response # The response message must include either `content` (i.e. a recordset) or # an `Error`. 
We choose the latter in this case error_proto = error_to_proto(Error(code=0)) mock_response = Mock( task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) - self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response + self.mock_stub.PullTaskRes.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -166,9 +169,9 @@ def test_send_and_receive_messages_timeout(self) -> None: # Prepare sleep_fn = time.sleep mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response + self.mock_stub.PushTaskIns.return_value = mock_response mock_response = Mock(task_res_list=[]) - self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response + self.mock_stub.PullTaskRes.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -182,22 +185,20 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" - # Prepare - self.mock_grpc_driver_stub.is_connected.return_value = True - # Execute self.driver.close() # Assert - self.mock_grpc_driver_stub.disconnect.assert_called_once() + self.mock_channel.close.assert_called_once() def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" # Prepare - self.mock_grpc_driver_stub.is_connected.return_value = False + self.driver._grpc_stub = None # pylint: disable=protected-access + self.driver._channel = None # pylint: disable=protected-access # Execute self.driver.close() # Assert - self.mock_grpc_driver_stub.disconnect.assert_not_called() + self.mock_channel.close.assert_not_called() diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index eff38f548826..d0f32e830f7d 100644 --- 
a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -15,7 +15,6 @@ """Tests for in-memory driver.""" -import os import time import unittest from typing import Iterable, List, Tuple @@ -23,7 +22,7 @@ from uuid import uuid4 from flwr.common import RecordSet -from flwr.common.constant import PING_MAX_INTERVAL +from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL from flwr.common.message import Error from flwr.common.serde import ( error_to_proto, @@ -34,6 +33,7 @@ from flwr.common.typing import Run from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory +from flwr.server.superlink.state.utils import generate_rand_int_from_bytes from .inmemory_driver import InMemoryDriver @@ -82,11 +82,14 @@ def setUp(self) -> None: self.num_nodes = 42 self.state = MagicMock() self.state.get_nodes.return_value = [ - int.from_bytes(os.urandom(8), "little", signed=True) + generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) for _ in range(self.num_nodes) ] self.state.get_run.return_value = Run( - run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" + run_id=61016, + fab_id="mock/mock", + fab_version="v1.0.0", + override_config={"test_key": "test_value"}, ) state_factory = MagicMock(state=lambda: self.state) self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) @@ -98,6 +101,7 @@ def test_get_run(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.override_config["test_key"], "test_value") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: # Prepare state = StateFactory("").state() self.driver = InMemoryDriver( - state.create_run("", ""), 
MagicMock(state=lambda: state) + state.create_run("", "", {}), MagicMock(state=lambda: state) ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, SqliteState) @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = StateFactory(":flwr-in-memory-state:") state = state_factory.state() - self.driver = InMemoryDriver(state.create_run("", ""), state_factory) + self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, InMemoryState) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 3505ebfdb0a9..4cc25feb7e0e 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -19,16 +19,24 @@ import sys from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Optional +from typing import Dict, Optional from flwr.common import Context, EventType, RecordSet, event -from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir +from flwr.common.config import ( + get_flwr_dir, + get_fused_config, + get_project_config, + get_project_dir, +) from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app -from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, +) from .driver import Driver -from .driver.grpc_driver import GrpcDriver, GrpcDriverStub +from .driver.grpc_driver import GrpcDriver from .server_app import LoadServerAppError, ServerApp ADDRESS_DRIVER_API = "0.0.0.0:9091" @@ -37,6 +45,7 @@ def run( driver: Driver, server_app_dir: str, + server_app_run_config: Dict[str, str], server_app_attr: Optional[str] = None, loaded_server_app: Optional[ServerApp] = None, ) -> None: @@ 
-69,7 +78,9 @@ def _load() -> ServerApp: server_app = _load() # Initialize Context - context = Context(state=RecordSet()) + context = Context( + node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config + ) # Call ServerApp server_app(driver=driver, context=context) @@ -144,22 +155,29 @@ def run_server_app() -> None: # pylint: disable=too-many-branches "For more details, use: ``flower-server-app -h``" ) - stub = GrpcDriverStub( - driver_service_address=args.superlink, root_certificates=root_certificates - ) + # Initialize GrpcDriver if args.run_id is not None: # User provided `--run-id`, but not `server-app` - run_id = args.run_id + driver = GrpcDriver( + run_id=args.run_id, + driver_service_address=args.superlink, + root_certificates=root_certificates, + ) else: # User provided `server-app`, but not `--run-id` # Create run if run_id is not provided - stub.connect() + driver = GrpcDriver( + run_id=0, # Will be overwritten + driver_service_address=args.superlink, + root_certificates=root_certificates, + ) + # Create run req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version) - res = stub.create_run(req) - run_id = res.run_id + res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212 + # Overwrite driver._run_id + driver._run_id = res.run_id # pylint: disable=W0212 - # Initialize GrpcDriver - driver = GrpcDriver(run_id=run_id, stub=stub) + server_app_run_config = {} # Dynamically obtain ServerApp path based on run_id if args.run_id is not None: @@ -169,6 +187,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir)) config = get_project_config(server_app_dir) server_app_attr = config["flower"]["components"]["serverapp"] + server_app_run_config = get_fused_config(run_, flwr_dir) else: # User provided `server-app`, but not `--run-id` server_app_dir = str(Path(args.dir).absolute()) @@ -182,7 +201,12 @@ def run_server_app() 
-> None: # pylint: disable=too-many-branches ) # Run the ServerApp with the Driver - run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) + run( + driver=driver, + server_app_dir=server_app_dir, + server_app_run_config=server_app_run_config, + server_app_attr=server_app_attr, + ) # Clean up driver.close() diff --git a/src/py/flwr/server/server_app.py b/src/py/flwr/server/server_app.py index 43b3bcce3f36..e9cb4ddcaf0d 100644 --- a/src/py/flwr/server/server_app.py +++ b/src/py/flwr/server/server_app.py @@ -17,8 +17,11 @@ from typing import Callable, Optional -from flwr.common import Context, RecordSet -from flwr.common.logger import warn_preview_feature +from flwr.common import Context +from flwr.common.logger import ( + warn_deprecated_feature_with_example, + warn_preview_feature, +) from flwr.server.strategy import Strategy from .client_manager import ClientManager @@ -26,7 +29,20 @@ from .driver import Driver from .server import Server from .server_config import ServerConfig -from .typing import ServerAppCallable +from .typing import ServerAppCallable, ServerFn + +SERVER_FN_USAGE_EXAMPLE = """ + + def server_fn(context: Context): + server_config = ServerConfig(num_rounds=3) + strategy = FedAvg() + return ServerAppComponents( + strategy=strategy, + server_config=server_config, + ) + + app = ServerApp(server_fn=server_fn) +""" class ServerApp: @@ -36,13 +52,15 @@ class ServerApp: -------- Use the `ServerApp` with an existing `Strategy`: - >>> server_config = ServerConfig(num_rounds=3) - >>> strategy = FedAvg() + >>> def server_fn(context: Context): + >>> server_config = ServerConfig(num_rounds=3) + >>> strategy = FedAvg() + >>> return ServerAppComponents( + >>> strategy=strategy, + >>> server_config=server_config, + >>> ) >>> - >>> app = ServerApp( - >>> server_config=server_config, - >>> strategy=strategy, - >>> ) + >>> app = ServerApp(server_fn=server_fn) Use the `ServerApp` with a custom main function: @@ -53,23 +71,52 @@ class 
ServerApp: >>> print("ServerApp running") """ + # pylint: disable=too-many-arguments def __init__( self, server: Optional[Server] = None, config: Optional[ServerConfig] = None, strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, + server_fn: Optional[ServerFn] = None, ) -> None: + if any([server, config, strategy, client_manager]): + warn_deprecated_feature_with_example( + deprecation_message="Passing either `server`, `config`, `strategy` or " + "`client_manager` directly to the ServerApp " + "constructor is deprecated.", + example_message="Pass `ServerApp` arguments wrapped " + "in a `flwr.server.ServerAppComponents` object that gets " + "returned by a function passed as the `server_fn` argument " + "to the `ServerApp` constructor. For example: ", + code_example=SERVER_FN_USAGE_EXAMPLE, + ) + + if server_fn: + raise ValueError( + "Passing `server_fn` is incompatible with passing the " + "other arguments (now deprecated) to ServerApp. " + "Use `server_fn` exclusively." 
+ ) + self._server = server self._config = config self._strategy = strategy self._client_manager = client_manager + self._server_fn = server_fn self._main: Optional[ServerAppCallable] = None def __call__(self, driver: Driver, context: Context) -> None: """Execute `ServerApp`.""" # Compatibility mode if not self._main: + if self._server_fn: + # Execute server_fn() + components = self._server_fn(context) + self._server = components.server + self._config = components.config + self._strategy = components.strategy + self._client_manager = components.client_manager start_driver( server=self._server, config=self._config, @@ -80,7 +127,6 @@ def __call__(self, driver: Driver, context: Context) -> None: return # New execution mode - context = Context(state=RecordSet()) self._main(driver, context) def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]: diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index 0751a0cb2bc5..b0672b3202ed 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -29,7 +29,7 @@ def test_server_app_custom_mode() -> None: # Prepare app = ServerApp() driver = MagicMock() - context = Context(state=RecordSet()) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) called = {"called": False} diff --git a/src/py/flwr/server/serverapp_components.py b/src/py/flwr/server/serverapp_components.py new file mode 100644 index 000000000000..315f0a889a61 --- /dev/null +++ b/src/py/flwr/server/serverapp_components.py @@ -0,0 +1,52 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ServerAppComponents for the ServerApp.""" + + +from dataclasses import dataclass +from typing import Optional + +from .client_manager import ClientManager +from .server import Server +from .server_config import ServerConfig +from .strategy import Strategy + + +@dataclass +class ServerAppComponents: # pylint: disable=too-many-instance-attributes + """Components to construct a ServerApp. + + Parameters + ---------- + server : Optional[Server] (default: None) + A server implementation, either `flwr.server.Server` or a subclass + thereof. If no instance is provided, one will be created internally. + config : Optional[ServerConfig] (default: None) + Currently supported values are `num_rounds` (int, default: 1) and + `round_timeout` in seconds (float, default: None). + strategy : Optional[Strategy] (default: None) + An implementation of the abstract base class + `flwr.server.strategy.Strategy`. If no strategy is provided, then + `flwr.server.strategy.FedAvg` will be used. + client_manager : Optional[ClientManager] (default: None) + An implementation of the class `flwr.server.ClientManager`. If no + implementation is provided, then `flwr.server.SimpleClientManager` + will be used. 
+ """ + + server: Optional[Server] = None + config: Optional[ServerConfig] = None + strategy: Optional[Strategy] = None + client_manager: Optional[ClientManager] = None diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 03128f02158e..7f8ded3bdb85 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -69,7 +69,11 @@ def CreateRun( """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") state: State = self.state_factory.state() - run_id = state.create_run(request.fab_id, request.fab_version) + run_id = state.create_run( + request.fab_id, + request.fab_version, + dict(request.override_config.items()), + ) return CreateRunResponse(run_id=run_id) def PushTaskIns( diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 01499102b7d8..798e71435585 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._client_private_key, self._server_public_key @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) client_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(client_private_key, 
self._server_public_key) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py index 1d5e3a6a51ad..31c64bd3b233 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -33,8 +33,8 @@ def __init__(self, backend_config: BackendConfig, work_dir: str) -> None: """Construct a backend.""" @abstractmethod - async def build(self) -> None: - """Build backend asynchronously. + def build(self) -> None: + """Build backend. Different components need to be in place before workers in a backend are ready to accept jobs. When this method finishes executing, the backend should be fully @@ -54,11 +54,11 @@ def is_worker_idle(self) -> bool: """Report whether a backend worker is idle and can therefore run a ClientApp.""" @abstractmethod - async def terminate(self) -> None: + def terminate(self) -> None: """Terminate backend.""" @abstractmethod - async def process_message( + def process_message( self, app: Callable[[], ClientApp], message: Message, diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 8a21393db590..0ab29a234f88 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -21,6 +21,7 @@ import ray from flwr.client.client_app import ClientApp +from flwr.common.constant import PARTITION_ID_KEY from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message @@ -153,12 +154,12 @@ def is_worker_idle(self) -> bool: """Report whether the pool has idle actors.""" return self.pool.is_actor_available() - async def build(self) -> None: + def build(self) -> None: """Build pool of Ray actors that this backend will submit jobs to.""" - await self.pool.add_actors_to_pool(self.pool.actors_capacity) + 
self.pool.add_actors_to_pool(self.pool.actors_capacity) log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors) - async def process_message( + def process_message( self, app: Callable[[], ClientApp], message: Message, @@ -168,21 +169,20 @@ async def process_message( Return output message and updated context. """ - partition_id = message.metadata.partition_id + partition_id = context.node_config[PARTITION_ID_KEY] try: # Submit a task to the pool - future = await self.pool.submit( + future = self.pool.submit( lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), (app, message, str(partition_id), context), ) - await future # Fetch result ( out_mssg, updated_context, - ) = await self.pool.fetch_result_and_return_actor_to_pool(future) + ) = self.pool.fetch_result_and_return_actor_to_pool(future) return out_mssg, updated_context @@ -193,11 +193,11 @@ async def process_message( self.__class__.__name__, ) # add actor back into pool - await self.pool.add_actor_back_to_pool(future) + self.pool.add_actor_back_to_pool(future) raise ex - async def terminate(self) -> None: + def terminate(self) -> None: """Terminate all actors in actor pool.""" - await self.pool.terminate_all_actors() + self.pool.terminate_all_actors() ray.shutdown() log(DEBUG, "Terminated %s", self.__class__.__name__) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 57c952cc9310..a38cff96ceef 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -14,16 +14,16 @@ # ============================================================================== """Test for Ray backend for the Fleet API using the Simulation Engine.""" -import asyncio from math import pi from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Union -from unittest import IsolatedAsyncioTestCase 
+from unittest import TestCase import ray from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.node_state import NodeState from flwr.common import ( DEFAULT_TTL, Config, @@ -33,9 +33,9 @@ Message, MessageTypeLegacy, Metadata, - RecordSet, Scalar, ) +from flwr.common.constant import PARTITION_ID_KEY from flwr.common.object_ref import load_app from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig @@ -54,7 +54,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: return {"result": result} -def get_dummy_client(cid: str) -> Client: # pylint: disable=unused-argument +def get_dummy_client(context: Context) -> Client: # pylint: disable=unused-argument """Return a DummyClient converted to Client type.""" return DummyClient().to_client() @@ -82,18 +82,18 @@ def _load_app() -> ClientApp: return _load_app -async def backend_build_process_and_termination( +def backend_build_process_and_termination( backend: RayBackend, process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None, ) -> Union[Tuple[Message, Context], None]: """Build, process job and terminate RayBackend.""" - await backend.build() + backend.build() to_return = None if process_args: - to_return = await backend.process_message(*process_args) + to_return = backend.process_message(*process_args) - await backend.terminate() + backend.terminate() return to_return @@ -102,12 +102,13 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: # Construct a Message mult_factor = 2024 + run_id = 0 getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) recordset = getpropertiesins_to_recordset(getproperties_ins) message = Message( content=recordset, metadata=Metadata( - run_id=0, + run_id=run_id, message_id="", group_id="", src_node_id=0, @@ -118,8 +119,10 @@ def _create_message_and_context() -> 
Tuple[Message, Context, float]: ), ) - # Construct emtpy Context - context = Context(state=RecordSet()) + # Construct NodeState and retrieve context + node_state = NodeState(node_id=run_id, node_config={PARTITION_ID_KEY: str(0)}) + node_state.register_context(run_id=run_id) + context = node_state.retrieve_context(run_id=run_id) # Expected output expected_output = pi * mult_factor @@ -127,10 +130,10 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: return message, context, expected_output -class AsyncTestRayBackend(IsolatedAsyncioTestCase): - """A basic class that allows runnig multliple asyncio tests.""" +class TestRayBackend(TestCase): + """A basic class that allows runnig multliple tests.""" - async def on_cleanup(self) -> None: + def doCleanups(self) -> None: """Ensure Ray has shutdown.""" if ray.is_initialized(): ray.shutdown() @@ -138,9 +141,7 @@ async def on_cleanup(self) -> None: def test_backend_creation_and_termination(self) -> None: """Test creation of RayBackend and its termination.""" backend = RayBackend(backend_config={}, work_dir="") - asyncio.run( - backend_build_process_and_termination(backend=backend, process_args=None) - ) + backend_build_process_and_termination(backend=backend, process_args=None) def test_backend_creation_submit_and_termination( self, @@ -155,10 +156,8 @@ def test_backend_creation_submit_and_termination( message, context, expected_output = _create_message_and_context() - res = asyncio.run( - backend_build_process_and_termination( - backend=backend, process_args=(client_app_callable, message, context) - ) + res = backend_build_process_and_termination( + backend=backend, process_args=(client_app_callable, message, context) ) if res is None: @@ -187,7 +186,6 @@ def test_backend_creation_submit_and_termination_non_existing_client_app( self.test_backend_creation_submit_and_termination( client_app_loader=_load_from_module("a_non_existing_module:app") ) - self.addAsyncCleanup(self.on_cleanup) def 
test_backend_creation_submit_and_termination_existing_client_app( self, @@ -215,7 +213,6 @@ def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdi client_app_loader=_load_from_module("raybackend_test:client_app"), workdir="/?&%$^#%@$!", ) - self.addAsyncCleanup(self.on_cleanup) def test_backend_creation_with_init_arguments(self) -> None: """Testing whether init args are properly parsed to Ray.""" @@ -246,5 +243,3 @@ def test_backend_creation_with_init_arguments(self) -> None: nodes = ray.nodes() assert nodes[0]["Resources"]["CPU"] == backend_config_2["init_args"]["num_cpus"] - - self.addAsyncCleanup(self.on_cleanup) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 3c9628a6d2a3..cd30c40167c5 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -14,24 +14,33 @@ # ============================================================================== """Fleet Simulation Engine API.""" -import asyncio + import json import sys +import threading import time import traceback +from concurrent.futures import ThreadPoolExecutor from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path -from typing import Callable, Dict, List, Optional +from queue import Empty, Queue +from time import sleep +from typing import Callable, Dict, Optional from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState -from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode +from flwr.common.constant import ( + NUM_PARTITIONS_KEY, + PARTITION_ID_KEY, + PING_MAX_INTERVAL, + ErrorCode, +) from flwr.common.logger import log from flwr.common.message import Error from flwr.common.object_ref import load_app from flwr.common.serde import message_from_taskins, message_to_taskres -from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -from 
flwr.server.superlink.state import StateFactory +from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.server.superlink.state import State, StateFactory from .backend import Backend, error_messages_backends, supported_backends @@ -52,33 +61,32 @@ def _register_nodes( # pylint: disable=too-many-arguments,too-many-locals -async def worker( +def worker( app_fn: Callable[[], ClientApp], - queue: "asyncio.Queue[TaskIns]", + taskins_queue: "Queue[TaskIns]", + taskres_queue: "Queue[TaskRes]", node_states: Dict[int, NodeState], - state_factory: StateFactory, - nodes_mapping: NodeToPartitionMapping, backend: Backend, + f_stop: threading.Event, ) -> None: """Get TaskIns from queue and pass it to an actor in the pool to execute it.""" - state = state_factory.state() - while True: + while not f_stop.is_set(): out_mssg = None try: - task_ins: TaskIns = await queue.get() + # Fetch from queue with timeout. We use a timeout so + # the stopping event can be evaluated even when the queue is empty. 
+ task_ins: TaskIns = taskins_queue.get(timeout=1.0) node_id = task_ins.task.consumer.node_id - # Register and retrieve runstate + # Register and retrieve context node_states[node_id].register_context(run_id=task_ins.run_id) context = node_states[node_id].retrieve_context(run_id=task_ins.run_id) # Convert TaskIns to Message message = message_from_taskins(task_ins) - # Set partition_id - message.metadata.partition_id = nodes_mapping[node_id] # Let backend process message - out_mssg, updated_context = await backend.process_message( + out_mssg, updated_context = backend.process_message( app_fn, message, context ) @@ -86,11 +94,9 @@ async def worker( node_states[node_id].update_context( task_ins.run_id, context=updated_context ) - - except asyncio.CancelledError as e: - log(DEBUG, "Terminating async worker: %s", e) - break - + except Empty: + # An exception raised if queue.get times out + pass # Exceptions aren't raised but reported as an error message except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, ex) @@ -114,67 +120,48 @@ async def worker( task_res = message_to_taskres(out_mssg) # Store TaskRes in state task_res.task.pushed_at = time.time() - state.store_task_res(task_res) + taskres_queue.put(task_res) -async def add_taskins_to_queue( - queue: "asyncio.Queue[TaskIns]", - state_factory: StateFactory, +def add_taskins_to_queue( + state: State, + queue: "Queue[TaskIns]", nodes_mapping: NodeToPartitionMapping, - backend: Backend, - consumers: List["asyncio.Task[None]"], - f_stop: asyncio.Event, + f_stop: threading.Event, ) -> None: - """Retrieve TaskIns and add it to the queue.""" - state = state_factory.state() - num_initial_consumers = len(consumers) + """Put TaskIns in a queue from State.""" while not f_stop.is_set(): for node_id in nodes_mapping.keys(): - task_ins = state.get_task_ins(node_id=node_id, limit=1) - if task_ins: - await queue.put(task_ins[0]) - - # Count consumers that are running - num_active = sum(not (cc.done()) for cc in 
consumers) - - # Alert if number of consumers decreased by half - if num_active < num_initial_consumers // 2: - log( - WARN, - "Number of active workers has more than halved: (%i/%i active)", - num_active, - num_initial_consumers, - ) + task_ins_list = state.get_task_ins(node_id=node_id, limit=1) + for task_ins in task_ins_list: + queue.put(task_ins) + sleep(0.1) - # Break if consumers died - if num_active == 0: - raise RuntimeError("All workers have died. Ending Simulation.") - # Log some stats - log( - DEBUG, - "Simulation Engine stats: " - "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)", - num_active, - num_initial_consumers, - backend.__class__.__name__, - backend.num_workers, - queue.qsize(), - ) - await asyncio.sleep(1.0) - log(DEBUG, "Async producer: Stopped pulling from StateFactory.") +def put_taskres_into_state( + state: State, queue: "Queue[TaskRes]", f_stop: threading.Event +) -> None: + """Put TaskRes into State from a queue.""" + while not f_stop.is_set(): + try: + taskres = queue.get(timeout=1.0) + state.store_task_res(taskres) + except Empty: + # queue is empty when timeout was triggered + pass -async def run( +def run( app_fn: Callable[[], ClientApp], backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, state_factory: StateFactory, node_states: Dict[int, NodeState], - f_stop: asyncio.Event, + f_stop: threading.Event, ) -> None: - """Run the VCE async.""" - queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128) + """Run the VCE.""" + taskins_queue: "Queue[TaskIns]" = Queue() + taskres_queue: "Queue[TaskRes]" = Queue() try: @@ -182,29 +169,48 @@ async def run( backend = backend_fn() # Build backend - await backend.build() + backend.build() # Add workers (they submit Messages to Backend) - worker_tasks = [ - asyncio.create_task( - worker( - app_fn, queue, node_states, state_factory, nodes_mapping, backend - ) - ) - for _ in range(backend.num_workers) - ] - # Create producer (adds TaskIns into Queue) - producer = 
asyncio.create_task( - add_taskins_to_queue( - queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop - ) + state = state_factory.state() + + extractor_th = threading.Thread( + target=add_taskins_to_queue, + args=( + state, + taskins_queue, + nodes_mapping, + f_stop, + ), + ) + extractor_th.start() + + injector_th = threading.Thread( + target=put_taskres_into_state, + args=( + state, + taskres_queue, + f_stop, + ), ) + injector_th.start() + + with ThreadPoolExecutor() as executor: + _ = [ + executor.submit( + worker, + app_fn, + taskins_queue, + taskres_queue, + node_states, + backend, + f_stop, + ) + for _ in range(backend.num_workers) + ] - # Wait for producer to finish - # The producer runs forever until f_stop is set or until - # all worker (consumer) coroutines are completed. Workers - # also run forever and only end if an exception is raised. - await asyncio.gather(producer) + extractor_th.join() + injector_th.join() except Exception as ex: @@ -219,18 +225,9 @@ async def run( raise RuntimeError("Simulation Engine crashed.") from ex finally: - # Produced task terminated, now cancel worker tasks - for w_t in worker_tasks: - _ = w_t.cancel() - - while not all(w_t.done() for w_t in worker_tasks): - log(DEBUG, "Terminating async workers...") - await asyncio.sleep(0.5) - - await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()]) # Terminate backend - await backend.terminate() + backend.terminate() # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches @@ -239,7 +236,7 @@ def start_vce( backend_name: str, backend_config_json_stream: str, app_dir: str, - f_stop: asyncio.Event, + f_stop: threading.Event, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, num_supernodes: Optional[int] = None, @@ -291,8 +288,16 @@ def start_vce( # Construct mapping of NodeStates node_states: Dict[int, NodeState] = {} - for node_id in nodes_mapping: - node_states[node_id] = NodeState() + # Number of 
unique partitions + num_partitions = len(set(nodes_mapping.values())) + for node_id, partition_id in nodes_mapping.items(): + node_states[node_id] = NodeState( + node_id=node_id, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(num_partitions), + }, + ) # Load backend config log(DEBUG, "Supported backends: %s", list(supported_backends.keys())) @@ -343,15 +348,13 @@ def _load() -> ClientApp: _ = app_fn() # Run main simulation loop - asyncio.run( - run( - app_fn, - backend_fn, - nodes_mapping, - state_factory, - node_states, - f_stop, - ) + run( + app_fn, + backend_fn, + nodes_mapping, + state_factory, + node_states, + f_stop, ) except LoadClientAppError as loadapp_ex: f_stop_delay = 10 diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index df9f2cc96f95..7d37f03f6ade 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -15,7 +15,6 @@ """Test Fleet Simulation Engine API.""" -import asyncio import threading import time from itertools import cycle @@ -24,7 +23,7 @@ from pathlib import Path from time import sleep from typing import Dict, Optional, Set, Tuple -from unittest import IsolatedAsyncioTestCase +from unittest import TestCase from uuid import UUID from flwr.client.client_app import LoadClientAppError @@ -46,7 +45,7 @@ from flwr.server.superlink.state import InMemoryState, StateFactory -def terminate_simulation(f_stop: asyncio.Event, sleep_duration: int) -> None: +def terminate_simulation(f_stop: threading.Event, sleep_duration: int) -> None: """Set event to terminate Simulation Engine after `sleep_duration` seconds.""" sleep(sleep_duration) f_stop.set() @@ -82,7 +81,9 @@ def register_messages_into_state( ) -> Dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore - state.run_ids[run_id] = 
Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0") + state.run_ids[run_id] = Run( + run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={} + ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes @@ -146,15 +147,15 @@ def start_and_shutdown( ) -> None: """Start Simulation Engine and terminate after specified number of seconds. - Some tests need to be terminated by triggering externally an asyncio.Event. This - is enabled whtn passing `duration`>0. + Some tests need to be terminated by triggering externally an threading.Event. This + is enabled when passing `duration`>0. """ - f_stop = asyncio.Event() + f_stop = threading.Event() if duration: # Setup thread that will set the f_stop event, triggering the termination of all - # asyncio logic in the Simulation Engine. It will also terminate the Backend. + # logic in the Simulation Engine. It will also terminate the Backend. 
termination_th = threading.Thread( target=terminate_simulation, args=(f_stop, duration) ) @@ -179,8 +180,8 @@ def start_and_shutdown( termination_th.join() -class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase): - """A basic class that enables testing asyncio functionalities.""" +class TestFleetSimulationEngineRayBackend(TestCase): + """A basic class that enables testing functionalities.""" def test_erroneous_no_supernodes_client_mapping(self) -> None: """Test with unset arguments.""" diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index da9c754c3115..bc4bd4478a23 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -15,7 +15,6 @@ """In-memory State implementation.""" -import os import threading import time from logging import ERROR @@ -23,12 +22,13 @@ from uuid import UUID, uuid4 from flwr.common import log, now +from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES from flwr.common.typing import Run from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state.state import State from flwr.server.utils import validate_task_ins_or_res -from .utils import make_node_unavailable_taskres +from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres class InMemoryState(State): # pylint: disable=R0902,R0904 @@ -216,7 +216,7 @@ def create_node( ) -> int: """Create, store in state, and return `node_id`.""" # Sample a random int64 as node_id - node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) with self.lock: if node_id in self.node_ids: @@ -275,15 +275,23 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" return self.public_key_to_node_id.get(client_public_key) - def 
create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id with self.lock: - run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) if run_id not in self.run_ids: self.run_ids[run_id] = Run( - run_id=run_id, fab_id=fab_id, fab_version=fab_version + run_id=run_id, + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, ) return run_id log(ERROR, "Unexpected run creation failure.") diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 4df9470ded62..ea6f349b9f9a 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -15,7 +15,7 @@ """SQLite based implemenation of server state.""" -import os +import json import re import sqlite3 import time @@ -24,6 +24,7 @@ from uuid import UUID, uuid4 from flwr.common import log, now +from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 @@ -31,7 +32,7 @@ from flwr.server.utils.validator import validate_task_ins_or_res from .state import State -from .utils import make_node_unavailable_taskres +from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres SQL_CREATE_TABLE_NODE = """ CREATE TABLE IF NOT EXISTS node( @@ -61,9 +62,10 @@ SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( - run_id INTEGER UNIQUE, - fab_id TEXT, - fab_version TEXT + run_id INTEGER UNIQUE, + fab_id TEXT, + fab_version TEXT, + override_config TEXT ); """ @@ -541,7 +543,7 @@ def create_node( ) -> int: """Create, store in 
state, and return `node_id`.""" # Sample a random int64 as node_id - node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) query = "SELECT node_id FROM node WHERE public_key = :public_key;" row = self.query(query, {"public_key": public_key}) @@ -613,17 +615,27 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: return node_id return None - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id - run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) # Check conflicts query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" # If run_id does not exist if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);" - self.query(query, (run_id, fab_id, fab_version)) + query = ( + "INSERT INTO run (run_id, fab_id, fab_version, override_config)" + "VALUES (?, ?, ?, ?);" + ) + self.query( + query, (run_id, fab_id, fab_version, json.dumps(override_config)) + ) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -687,7 +699,10 @@ def get_run(self, run_id: int) -> Optional[Run]: try: row = self.query(query, (run_id,))[0] return Run( - run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"] + run_id=run_id, + fab_id=row["fab_id"], + fab_version=row["fab_version"], + override_config=json.loads(row["override_config"]), ) except sqlite3.IntegrityError: log(ERROR, "`run_id` does not exist.") diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 65e2c63cab69..c93f6ba756b8 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ 
b/src/py/flwr/server/superlink/state/state.py @@ -16,7 +16,7 @@ import abc -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set from uuid import UUID from flwr.common.typing import Run @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" @abc.abstractmethod - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 373202d5cde6..5f0d23ffc4d8 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -52,7 +52,7 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("Mock/mock", "v1.0.0") + run_id = state.create_run("Mock/mock", "v1.0.0", {"test_key": "test_value"}) # Execute run = state.get_run(run_id) @@ -62,6 +62,7 @@ def test_create_and_get_run(self) -> None: assert run.run_id == run_id assert run.fab_id == "Mock/mock" assert run.fab_version == "v1.0.0" + assert run.override_config["test_key"] == "test_value" def test_get_task_ins_empty(self) -> None: """Validate that a new state has no TaskIns.""" @@ -90,7 +91,7 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -125,7 +126,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - 
run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_0 = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -199,7 +200,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -214,7 +215,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -228,7 +229,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -242,7 +243,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -259,7 +260,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) 
task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -302,7 +303,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, @@ -323,7 +324,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -335,7 +336,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [] # Execute @@ -352,7 +353,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -368,7 +369,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -390,7 +391,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10) # 
Execute @@ -405,7 +406,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -422,7 +423,7 @@ def test_delete_node_public_key_none(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = 0 # Execute & Assert @@ -441,7 +442,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -460,7 +461,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -475,7 +476,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_run("mock/mock", "v1.0.0") + state.create_run("mock/mock", "v1.0.0", {}) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -489,7 +490,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = 
create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -507,7 +508,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_res( producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) @@ -608,7 +609,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -627,7 +628,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns diff --git a/src/py/flwr/server/superlink/state/utils.py b/src/py/flwr/server/superlink/state/utils.py index 233a90946cc7..b12a87ac998d 100644 --- a/src/py/flwr/server/superlink/state/utils.py +++ b/src/py/flwr/server/superlink/state/utils.py @@ -17,6 +17,7 @@ import time from logging import ERROR +from os import urandom from uuid import uuid4 from flwr.common import log @@ -31,6 +32,11 @@ ) +def generate_rand_int_from_bytes(num_bytes: int) -> int: + """Generate a random `num_bytes` integer.""" + return int.from_bytes(urandom(num_bytes), "little", signed=True) + + def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: """Generate a TaskRes with a node unavailable error from a 
TaskIns.""" current_time = time.time() diff --git a/src/py/flwr/server/typing.py b/src/py/flwr/server/typing.py index 01143af74392..cdb1c0db4fe7 100644 --- a/src/py/flwr/server/typing.py +++ b/src/py/flwr/server/typing.py @@ -20,6 +20,8 @@ from flwr.common import Context from .driver import Driver +from .serverapp_components import ServerAppComponents ServerAppCallable = Callable[[Driver, Context], None] Workflow = Callable[[Driver, Context], None] +ServerFn = Callable[[Context], ServerAppComponents] diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 856d6fc45e22..973a9a89e652 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -27,14 +27,16 @@ import ray from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy -from flwr.client import ClientFn +from flwr.client import ClientFnExt from flwr.common import EventType, event -from flwr.common.logger import log, set_logger_propagation +from flwr.common.constant import NODE_ID_NUM_BYTES +from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature from flwr.server.client_manager import ClientManager from flwr.server.history import History from flwr.server.server import Server, init_defaults, run_fl from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy +from flwr.server.superlink.state.utils import generate_rand_int_from_bytes from flwr.simulation.ray_transport.ray_actor import ( ClientAppActor, VirtualClientEngineActor, @@ -51,7 +53,7 @@ `start_simulation( *, client_fn: ClientFn, - num_clients: Optional[int] = None, + num_clients: int, clients_ids: Optional[List[str]] = None, client_resources: Optional[Dict[str, float]] = None, server: Optional[Server] = None, @@ -70,13 +72,29 @@ """ +NodeToPartitionMapping = Dict[int, int] + + +def _create_node_id_to_partition_mapping( + num_clients: int, +) -> NodeToPartitionMapping: + """Generate a node_id:partition_id mapping.""" + nodes_mapping: 
NodeToPartitionMapping = {} # {node-id; partition-id} + for i in range(num_clients): + while True: + node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) + if node_id not in nodes_mapping: + break + nodes_mapping[node_id] = i + return nodes_mapping + # pylint: disable=too-many-arguments,too-many-statements,too-many-branches def start_simulation( *, - client_fn: ClientFn, - num_clients: Optional[int] = None, - clients_ids: Optional[List[str]] = None, + client_fn: ClientFnExt, + num_clients: int, + clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED client_resources: Optional[Dict[str, float]] = None, server: Optional[Server] = None, config: Optional[ServerConfig] = None, @@ -92,23 +110,24 @@ def start_simulation( Parameters ---------- - client_fn : ClientFn - A function creating client instances. The function must take a single - `str` argument called `cid`. It should return a single client instance - of type Client. Note that the created client instances are ephemeral - and will often be destroyed after a single method invocation. Since client - instances are not long-lived, they should not attempt to carry state over - method invocations. Any state required by the instance (model, dataset, - hyperparameters, ...) should be (re-)created in either the call to `client_fn` - or the call to any of the client methods (e.g., load evaluation data in the - `evaluate` method itself). - num_clients : Optional[int] - The total number of clients in this simulation. This must be set if - `clients_ids` is not set and vice-versa. + client_fn : ClientFnExt + A function creating `Client` instances. The function must have the signature + `client_fn(context: Context). It should return + a single client instance of type `Client`. Note that the created client + instances are ephemeral and will often be destroyed after a single method + invocation. Since client instances are not long-lived, they should not attempt + to carry state over method invocations. 
Any state required by the instance + (model, dataset, hyperparameters, ...) should be (re-)created in either the + call to `client_fn` or the call to any of the client methods (e.g., load + evaluation data in the `evaluate` method itself). + num_clients : int + The total number of clients in this simulation. clients_ids : Optional[List[str]] + UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD. List `client_id`s for each client. This is only required if `num_clients` is not set. Setting both `num_clients` and `clients_ids` with `len(clients_ids)` not equal to `num_clients` generates an error. + Using this argument will raise an error. client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`) CPU and GPU resources for a single client. Supported keys are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by @@ -158,7 +177,6 @@ def start_simulation( is an advanced feature. For all details, please refer to the Ray documentation: https://docs.ray.io/en/latest/ray-core/scheduling/index.html - Returns ------- hist : flwr.server.history.History @@ -170,6 +188,14 @@ def start_simulation( {"num_clients": len(clients_ids) if clients_ids is not None else num_clients}, ) + if clients_ids is not None: + warn_unsupported_feature( + "Passing `clients_ids` to `start_simulation` is deprecated and not longer " + "used by `start_simulation`. Use `num_clients` exclusively instead." 
+ ) + log(ERROR, "`clients_ids` argument used.") + sys.exit() + # Set logger propagation loop: Optional[asyncio.AbstractEventLoop] = None try: @@ -196,20 +222,8 @@ def start_simulation( initialized_config, ) - # clients_ids takes precedence - cids: List[str] - if clients_ids is not None: - if (num_clients is not None) and (len(clients_ids) != num_clients): - log(ERROR, INVALID_ARGUMENTS_START_SIMULATION) - sys.exit() - else: - cids = clients_ids - else: - if num_clients is None: - log(ERROR, INVALID_ARGUMENTS_START_SIMULATION) - sys.exit() - else: - cids = [str(x) for x in range(num_clients)] + # Create node-id to partition-id mapping + nodes_mapping = _create_node_id_to_partition_mapping(num_clients) # Default arguments for Ray initialization if not ray_init_args: @@ -308,10 +322,12 @@ def update_resources(f_stop: threading.Event) -> None: ) # Register one RayClientProxy object for each client with the ClientManager - for cid in cids: + for node_id, partition_id in nodes_mapping.items(): client_proxy = RayActorClientProxy( client_fn=client_fn, - cid=cid, + node_id=node_id, + partition_id=partition_id, + num_partitions=num_clients, actor_pool=pool, ) initialized_server.client_manager().register(client=client_proxy) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 7afffb865334..b1c9d2b9c0c1 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -14,7 +14,6 @@ # ============================================================================== """Ray-based Flower Actor and ActorPool implementation.""" -import asyncio import threading from abc import ABC from logging import DEBUG, ERROR, WARNING @@ -411,9 +410,7 @@ def __init__( self.client_resources = client_resources # Queue of idle actors - self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue( - maxsize=1024 - ) + self.pool: List[VirtualClientEngineActor] = [] 
self.num_actors = 0 # Resolve arguments to pass during actor init @@ -427,38 +424,37 @@ def __init__( # Figure out how many actors can be created given the cluster resources # and the resources the user indicates each VirtualClient will need self.actors_capacity = pool_size_from_resources(client_resources) - self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {} + self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {} def is_actor_available(self) -> bool: """Return true if there is an idle actor.""" - return self.pool.qsize() > 0 + return len(self.pool) > 0 - async def add_actors_to_pool(self, num_actors: int) -> None: + def add_actors_to_pool(self, num_actors: int) -> None: """Add actors to the pool. This method may be executed also if new resources are added to your Ray cluster (e.g. you add a new node). """ for _ in range(num_actors): - await self.pool.put(self.create_actor_fn()) # type: ignore + self.pool.append(self.create_actor_fn()) # type: ignore self.num_actors += num_actors - async def terminate_all_actors(self) -> None: + def terminate_all_actors(self) -> None: """Terminate actors in pool.""" num_terminated = 0 - while self.pool.qsize(): - actor = await self.pool.get() + for actor in self.pool: actor.terminate.remote() # type: ignore num_terminated += 1 log(DEBUG, "Terminated %i actors", num_terminated) - async def submit( + def submit( self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] ) -> Any: """On idle actor, submit job and return future.""" # Remove idle actor from pool - actor = await self.pool.get() + actor = self.pool.pop() # Submit job to actor app_fn, mssg, cid, context = job future = actor_fn(actor, app_fn, mssg, cid, context) @@ -467,18 +463,18 @@ async def submit( self._future_to_actor[future] = actor return future - async def add_actor_back_to_pool(self, future: Any) -> None: + def add_actor_back_to_pool(self, future: Any) -> None: """Ad actor assigned to run future back into the pool.""" actor = 
self._future_to_actor.pop(future) - await self.pool.put(actor) + self.pool.append(actor) - async def fetch_result_and_return_actor_to_pool( + def fetch_result_and_return_actor_to_pool( self, future: Any ) -> Tuple[Message, Context]: """Pull result given a future and add actor back to pool.""" - # Get actor that ran job - await self.add_actor_back_to_pool(future) # Retrieve result for object store # Instead of doing ray.get(future) we await it - _, out_mssg, updated_context = await future + _, out_mssg, updated_context = ray.get(future) + # Get actor that ran job + self.add_actor_back_to_pool(future) return out_mssg, updated_context diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index d3d103bb377a..895272c2fd79 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -20,11 +20,16 @@ from typing import Optional from flwr import common -from flwr.client import ClientFn +from flwr.client import ClientFnExt from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet -from flwr.common.constant import MessageType, MessageTypeLegacy +from flwr.common.constant import ( + NUM_PARTITIONS_KEY, + PARTITION_ID_KEY, + MessageType, + MessageTypeLegacy, +) from flwr.common.logger import log from flwr.common.recordset_compat import ( evaluateins_to_recordset, @@ -43,17 +48,30 @@ class RayActorClientProxy(ClientProxy): """Flower client proxy which delegates work using Ray.""" - def __init__( - self, client_fn: ClientFn, cid: str, actor_pool: VirtualClientEngineActorPool + def __init__( # pylint: disable=too-many-arguments + self, + client_fn: ClientFnExt, + node_id: int, + partition_id: int, + num_partitions: int, + actor_pool: VirtualClientEngineActorPool, ): - super().__init__(cid) + super().__init__(cid=str(node_id)) + 
self.node_id = node_id + self.partition_id = partition_id def _load_app() -> ClientApp: return ClientApp(client_fn=client_fn) self.app_fn = _load_app self.actor_pool = actor_pool - self.proxy_state = NodeState() + self.proxy_state = NodeState( + node_id=node_id, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(num_partitions), + }, + ) def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: """Sumbit a message to the ActorPool.""" @@ -62,16 +80,19 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: # Register state self.proxy_state.register_context(run_id=run_id) - # Retrieve state - state = self.proxy_state.retrieve_context(run_id=run_id) + # Retrieve context + context = self.proxy_state.retrieve_context(run_id=run_id) + partition_id_str = context.node_config[PARTITION_ID_KEY] try: self.actor_pool.submit_client_job( - lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), - (self.app_fn, message, self.cid, state), + lambda a, a_fn, mssg, partition_id, context: a.run.remote( + a_fn, mssg, partition_id, context + ), + (self.app_fn, message, partition_id_str, context), ) out_mssg, updated_context = self.actor_pool.get_client_result( - self.cid, timeout + partition_id_str, timeout ) # Update state @@ -103,11 +124,10 @@ def _wrap_recordset_in_message( message_id="", group_id=str(group_id) if group_id is not None else "", src_node_id=0, - dst_node_id=int(self.cid), + dst_node_id=self.node_id, reply_to_message="", ttl=timeout if timeout else DEFAULT_TTL, message_type=message_type, - partition_id=int(self.cid), ), ) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 9680b3846f1d..62e0cfd61c99 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -23,6 +23,7 @@ from flwr.client import Client, 
NumPyClient from flwr.client.client_app import ClientApp +from flwr.client.node_state import NodeState from flwr.common import ( DEFAULT_TTL, Config, @@ -31,14 +32,18 @@ Message, MessageTypeLegacy, Metadata, - RecordSet, Scalar, ) +from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY from flwr.common.recordset_compat import ( getpropertiesins_to_recordset, recordset_to_getpropertiesres, ) from flwr.common.recordset_compat_test import _get_valid_getpropertiesins +from flwr.simulation.app import ( + NodeToPartitionMapping, + _create_node_id_to_partition_mapping, +) from flwr.simulation.ray_transport.ray_actor import ( ClientAppActor, VirtualClientEngineActor, @@ -50,12 +55,12 @@ class DummyClient(NumPyClient): """A dummy NumPyClient for tests.""" - def __init__(self, cid: str) -> None: - self.cid = int(cid) + def __init__(self, node_id: int) -> None: + self.node_id = node_id def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return properties by doing a simple calculation.""" - result = int(self.cid) * pi + result = self.node_id * pi # store something in context self.context.state.configs_records["result"] = ConfigsRecord( @@ -64,14 +69,16 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: return {"result": result} -def get_dummy_client(cid: str) -> Client: +def get_dummy_client(context: Context) -> Client: """Return a DummyClient converted to Client type.""" - return DummyClient(cid).to_client() + return DummyClient(context.node_id).to_client() def prep( actor_type: Type[VirtualClientEngineActor] = ClientAppActor, -) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover +) -> Tuple[ + List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping +]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0} @@ -87,16 +94,19 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]: # Create 373 client proxies 
num_proxies = 373 # a prime number + mapping = _create_node_id_to_partition_mapping(num_proxies) proxies = [ RayActorClientProxy( client_fn=get_dummy_client, - cid=str(cid), + node_id=node_id, + partition_id=partition_id, + num_partitions=num_proxies, actor_pool=pool, ) - for cid in range(num_proxies) + for node_id, partition_id in mapping.items() ] - return proxies, pool + return proxies, pool, mapping def test_cid_consistency_one_at_a_time() -> None: @@ -104,7 +114,7 @@ def test_cid_consistency_one_at_a_time() -> None: Submit one job and waits for completion. Then submits the next and so on """ - proxies, _ = prep() + proxies, _, _ = prep() getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) @@ -123,7 +133,7 @@ def test_cid_consistency_one_at_a_time() -> None: res = recordset_to_getpropertiesres(message_out.content) - assert int(prox.cid) * pi == res.properties["result"] + assert int(prox.node_id) * pi == res.properties["result"] ray.shutdown() @@ -134,7 +144,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: All jobs are submitted at the same time. Then fetched one at a time. This also tests NodeState (at each Proxy) and RunState basic functionality. 
""" - proxies, _ = prep() + proxies, _, _ = prep() run_id = 0 getproperties_ins = _get_valid_getpropertiesins() @@ -156,21 +166,21 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: ) prox.actor_pool.submit_client_job( lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), - (prox.app_fn, message, prox.cid, state), + (prox.app_fn, message, str(prox.node_id), state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: message_out, updated_context = prox.actor_pool.get_client_result( - prox.cid, timeout=None + str(prox.node_id), timeout=None ) prox.proxy_state.update_context(run_id, context=updated_context) res = recordset_to_getpropertiesres(message_out.content) - assert int(prox.cid) * pi == res.properties["result"] + assert prox.node_id * pi == res.properties["result"] assert ( - str(int(prox.cid) * pi) + str(prox.node_id * pi) == prox.proxy_state.retrieve_context(run_id).state.configs_records[ "result" ]["result"] @@ -181,9 +191,19 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: def test_cid_consistency_without_proxies() -> None: """Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy.""" - proxies, pool = prep() - num_clients = len(proxies) - cids = [str(cid) for cid in range(num_clients)] + _, pool, mapping = prep() + node_ids = list(mapping.keys()) + + # register node states + node_states: Dict[int, NodeState] = {} + for node_id, partition_id in mapping.items(): + node_states[node_id] = NodeState( + node_id=node_id, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(len(node_ids)), + }, + ) getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) @@ -192,32 +212,37 @@ def _load_app() -> ClientApp: return ClientApp(client_fn=get_dummy_client) # submit all jobs (collect later) - shuffle(cids) - for cid in cids: + shuffle(node_ids) + run_id = 0 + for node_id in node_ids: 
message = Message( content=recordset, metadata=Metadata( - run_id=0, + run_id=run_id, message_id="", group_id=str(0), src_node_id=0, - dst_node_id=12345, + dst_node_id=node_id, reply_to_message="", ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, - partition_id=int(cid), ), ) + # register and retrieve context + node_states[node_id].register_context(run_id=run_id) + context = node_states[node_id].retrieve_context(run_id=run_id) + partition_id_str = context.node_config[PARTITION_ID_KEY] pool.submit_client_job( - lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (_load_app, message, cid, Context(state=RecordSet())), + lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state), + (_load_app, message, partition_id_str, context), ) # fetch results one at a time - shuffle(cids) - for cid in cids: - message_out, _ = pool.get_client_result(cid, timeout=None) + shuffle(node_ids) + for node_id in node_ids: + partition_id_str = str(mapping[node_id]) + message_out, _ = pool.get_client_result(partition_id_str, timeout=None) res = recordset_to_getpropertiesres(message_out.content) - assert int(cid) * pi == res.properties["result"] + assert node_id * pi == res.properties["result"] ray.shutdown() diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 7c7a412a245b..de101fe3e09f 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -22,7 +22,7 @@ import traceback from logging import DEBUG, ERROR, INFO, WARNING from time import sleep -from typing import Optional +from typing import Dict, Optional from flwr.client import ClientApp from flwr.common import EventType, event, log @@ -126,16 +126,25 @@ def run_simulation( def run_serverapp_th( server_app_attr: Optional[str], server_app: Optional[ServerApp], + server_app_run_config: Dict[str, str], driver: Driver, app_dir: str, - f_stop: asyncio.Event, + f_stop: threading.Event, + 
has_exception: threading.Event, enable_tf_gpu_growth: bool, delay_launch: int = 3, ) -> threading.Thread: """Run SeverApp in a thread.""" - def server_th_with_start_checks( # type: ignore - tf_gpu_growth: bool, stop_event: asyncio.Event, **kwargs + def server_th_with_start_checks( + tf_gpu_growth: bool, + stop_event: threading.Event, + exception_event: threading.Event, + _driver: Driver, + _server_app_dir: str, + _server_app_run_config: Dict[str, str], + _server_app_attr: Optional[str], + _server_app: Optional[ServerApp], ) -> None: """Run SeverApp, after check if GPU memory growth has to be set. @@ -147,10 +156,18 @@ def server_th_with_start_checks( # type: ignore enable_gpu_growth() # Run ServerApp - run(**kwargs) + run( + driver=_driver, + server_app_dir=_server_app_dir, + server_app_run_config=_server_app_run_config, + server_app_attr=_server_app_attr, + loaded_server_app=_server_app, + ) except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, "ServerApp thread raised an exception: %s", ex) log(ERROR, traceback.format_exc()) + exception_event.set() + raise finally: log(DEBUG, "ServerApp finished running.") # Upon completion, trigger stop event if one was passed @@ -160,13 +177,16 @@ def server_th_with_start_checks( # type: ignore serverapp_th = threading.Thread( target=server_th_with_start_checks, - args=(enable_tf_gpu_growth, f_stop), - kwargs={ - "server_app_attr": server_app_attr, - "loaded_server_app": server_app, - "driver": driver, - "server_app_dir": app_dir, - }, + args=( + enable_tf_gpu_growth, + f_stop, + has_exception, + driver, + app_dir, + server_app_run_config, + server_app_attr, + server_app, + ), ) sleep(delay_launch) serverapp_th.start() @@ -196,20 +216,18 @@ def _main_loop( server_app: Optional[ServerApp] = None, server_app_attr: Optional[str] = None, ) -> None: - """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread. 
- - Everything runs on the main thread or a separate one, depending on whether the main - thread already contains a running Asyncio event loop. This is the case if running - the Simulation Engine on a Jupyter/Colab notebook. - """ + """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread.""" # Initialize StateFactory state_factory = StateFactory(":flwr-in-memory-state:") - f_stop = asyncio.Event() + f_stop = threading.Event() + # A Threading event to indicate if an exception was raised in the ServerApp thread + server_app_thread_has_exception = threading.Event() serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id_ = state_factory.state().create_run("", "") + run_id_ = state_factory.state().create_run("", "", {}) + server_app_run_config: Dict[str, str] = {} if run_id: _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id) @@ -222,9 +240,11 @@ def _main_loop( serverapp_th = run_serverapp_th( server_app_attr=server_app_attr, server_app=server_app, + server_app_run_config=server_app_run_config, driver=driver, app_dir=app_dir, f_stop=f_stop, + has_exception=server_app_thread_has_exception, enable_tf_gpu_growth=enable_tf_gpu_growth, ) @@ -253,6 +273,8 @@ def _main_loop( event(EventType.RUN_SUPERLINK_LEAVE) if serverapp_th: serverapp_th.join() + if server_app_thread_has_exception.is_set(): + raise RuntimeError("Exception in ServerApp thread") log(DEBUG, "Stopping Simulation Engine now.") @@ -349,7 +371,6 @@ def _run_simulation( # Convert config to original JSON-stream format backend_config_stream = json.dumps(backend_config) - simulation_engine_th = None args = ( num_supernodes, backend_name, @@ -363,31 +384,26 @@ def _run_simulation( server_app_attr, ) # Detect if there is an Asyncio event loop already running. - # If yes, run everything on a separate thread. In environments - # like Jupyter/Colab notebooks, there is an event loop present. 
- run_in_thread = False + # If yes, disable logger propagation. In environmnets + # like Jupyter/Colab notebooks, it's often better to do this. + asyncio_loop_running = False try: _ = ( asyncio.get_running_loop() ) # Raises RuntimeError if no event loop is present log(DEBUG, "Asyncio event loop already running.") - run_in_thread = True + asyncio_loop_running = True except RuntimeError: - log(DEBUG, "No asyncio event loop running") + pass finally: - if run_in_thread: + if asyncio_loop_running: # Set logger propagation to False to prevent duplicated log output in Colab. logger = set_logger_propagation(logger, False) - log(DEBUG, "Starting Simulation Engine on a new thread.") - simulation_engine_th = threading.Thread(target=_main_loop, args=args) - simulation_engine_th.start() - simulation_engine_th.join() - else: - log(DEBUG, "Starting Simulation Engine on the main thread.") - _main_loop(*args) + + _main_loop(*args) def _parse_args_run_simulation() -> argparse.ArgumentParser: diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index fa89e83b5e75..b51c3e6821dc 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -24,6 +24,7 @@ from flwr.common import EventType, event, log from flwr.common.address import parse_address +from flwr.common.config import parse_config_args from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS from flwr.common.exit_handlers import register_exit_handlers from flwr.common.object_ref import load_app, validate @@ -55,6 +56,7 @@ def run_superexec() -> None: address=address, executor=_load_executor(args), certificates=certificates, + config=parse_config_args(args.executor_config), ) grpc_servers = [superexec_server] @@ -74,20 +76,25 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Start a Flower SuperExec", ) - parser.add_argument( - "executor", - help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.", - ) 
parser.add_argument( "--address", help="SuperExec (gRPC) server address (IPv4, IPv6, or a domain name)", default=SUPEREXEC_DEFAULT_ADDRESS, ) + parser.add_argument( + "--executor", + help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.", + default="flwr.superexec.deployment:executor", + ) parser.add_argument( "--executor-dir", help="The directory for the executor.", default=".", ) + parser.add_argument( + "--executor-config", + help="Key-value pairs for the executor config, separated by commas.", + ) parser.add_argument( "--insecure", action="store_true", @@ -126,11 +133,11 @@ def _try_obtain_certificates( return None # Check if certificates are provided if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile: - if not Path.is_file(args.ssl_ca_certfile): + if not Path(args.ssl_ca_certfile).is_file(): sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.") - if not Path.is_file(args.ssl_certfile): + if not Path(args.ssl_certfile).is_file(): sys.exit("Path argument `--ssl-certfile` does not point to a file.") - if not Path.is_file(args.ssl_keyfile): + if not Path(args.ssl_keyfile).is_file(): sys.exit("Path argument `--ssl-keyfile` does not point to a file.") certificates = ( Path(args.ssl_ca_certfile).read_bytes(), # CA certificate diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 6f931e81eefa..f9a272e6b0bf 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -17,7 +17,8 @@ import subprocess import sys from logging import ERROR, INFO -from typing import Optional +from pathlib import Path +from typing import Dict, Optional from typing_extensions import override @@ -33,38 +34,99 @@ class DeploymentEngine(Executor): - """Deployment engine executor.""" + """Deployment engine executor. + + Parameters + ---------- + superlink: str (default: "0.0.0.0:9091") + Address of the SuperLink to connect to. 
+ root_certificates: Optional[str] (default: None) + Specifies the path to the PEM-encoded root certificate file for + establishing secure HTTPS connections. + flwr_dir: Optional[str] (default: None) + The path containing installed Flower Apps. + """ def __init__( self, - address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - root_certificates: Optional[bytes] = None, + superlink: str = DEFAULT_SERVER_ADDRESS_DRIVER, + root_certificates: Optional[str] = None, + flwr_dir: Optional[str] = None, ) -> None: - self.address = address - self.root_certificates = root_certificates + self.superlink = superlink + if root_certificates is None: + self.root_certificates = None + self.root_certificates_bytes = None + else: + self.root_certificates = root_certificates + self.root_certificates_bytes = Path(root_certificates).read_bytes() + self.flwr_dir = flwr_dir self.stub: Optional[DriverStub] = None - def _connect(self) -> None: - if self.stub is None: - channel = create_channel( - server_address=self.address, - insecure=(self.root_certificates is None), - root_certificates=self.root_certificates, - ) - self.stub = DriverStub(channel) + @override + def set_config( + self, + config: Dict[str, str], + ) -> None: + """Set executor config arguments. + + Parameters + ---------- + config : Dict[str, str] + A dictionary for configuration values. + Supported configuration key/value pairs: + - "superlink": str + The address of the SuperLink Driver API. + - "root-certificates": str + The path to the root certificates. + - "flwr-dir": str + The path to the Flower directory. 
+ """ + if not config: + return + if superlink_address := config.get("superlink"): + self.superlink = superlink_address + if root_certificates := config.get("root-certificates"): + self.root_certificates = root_certificates + self.root_certificates_bytes = Path(root_certificates).read_bytes() + if flwr_dir := config.get("flwr-dir"): + self.flwr_dir = flwr_dir - def _create_run(self, fab_id: str, fab_version: str) -> int: + def _connect(self) -> None: + if self.stub is not None: + return + channel = create_channel( + server_address=self.superlink, + insecure=(self.root_certificates_bytes is None), + root_certificates=self.root_certificates_bytes, + ) + self.stub = DriverStub(channel) + + def _create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: if self.stub is None: self._connect() assert self.stub is not None - req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version) + req = CreateRunRequest( + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, + ) res = self.stub.CreateRun(request=req) return int(res.run_id) @override - def start_run(self, fab_file: bytes) -> Optional[RunTracker]: + def start_run( + self, + fab_file: bytes, + override_config: Dict[str, str], + ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" try: # Install FAB to flwr dir @@ -79,7 +141,7 @@ def start_run(self, fab_file: bytes) -> Optional[RunTracker]: ) # Call SuperLink to create run - run_id: int = self._create_run(fab_id, fab_version) + run_id: int = self._create_run(fab_id, fab_version, override_config) log(INFO, "Created run %s", str(run_id)) # Start ServerApp @@ -88,7 +150,14 @@ def start_run(self, fab_file: bytes) -> Optional[RunTracker]: "flower-server-app", "--run-id", str(run_id), - "--insecure", + f"--flwr-dir {self.flwr_dir}" if self.flwr_dir else "", + "--superlink", + self.superlink, + ( + "--insecure" + if self.root_certificates is None + else f"--root-certificates 
{self.root_certificates}" + ), ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, diff --git a/src/py/flwr/superexec/exec_grpc.py b/src/py/flwr/superexec/exec_grpc.py index 127d5615dd84..d90cec3e47cd 100644 --- a/src/py/flwr/superexec/exec_grpc.py +++ b/src/py/flwr/superexec/exec_grpc.py @@ -15,7 +15,7 @@ """SuperExec gRPC API.""" from logging import INFO -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import grpc @@ -32,8 +32,11 @@ def run_superexec_api_grpc( address: str, executor: Executor, certificates: Optional[Tuple[bytes, bytes, bytes]], + config: Dict[str, str], ) -> grpc.Server: """Run SuperExec API (gRPC, request-response).""" + executor.set_config(config) + exec_servicer: grpc.Server = ExecServicer( executor=executor, ) @@ -45,7 +48,7 @@ def run_superexec_api_grpc( certificates=certificates, ) - log(INFO, "Flower ECE: Starting SuperExec API (gRPC-rere) on %s", address) + log(INFO, "Starting Flower SuperExec gRPC server on %s", address) superexec_grpc_server.start() return superexec_grpc_server diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index e5ef2bd59a79..61a7bc289af3 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -45,7 +45,10 @@ def StartRun( """Create run ID.""" log(INFO, "ExecServicer.StartRun") - run = self.executor.start_run(request.fab_file) + run = self.executor.start_run( + request.fab_file, + dict(request.override_config.items()), + ) if run is None: log(ERROR, "Executor failed to start run") diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index 41f67b74c48b..edc91df4530e 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -36,7 +36,7 @@ def test_start_run() -> None: run_res.proc = proc executor = MagicMock() - executor.start_run = lambda _: run_res + executor.start_run = lambda _, __: run_res 
context_mock = MagicMock() diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index f85ac4c157fc..62d64f366cec 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from subprocess import Popen -from typing import Optional +from typing import Dict, Optional @dataclass @@ -31,10 +31,24 @@ class RunTracker: class Executor(ABC): """Execute and monitor a Flower run.""" + @abstractmethod + def set_config( + self, + config: Dict[str, str], + ) -> None: + """Register provided config as class attributes. + + Parameters + ---------- + config : Optional[Dict[str, str]] + A dictionary for configuration values. + """ + @abstractmethod def start_run( self, fab_file: bytes, + override_config: Dict[str, str], ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version. @@ -45,6 +59,8 @@ def start_run( ---------- fab_file : bytes The Flower App Bundle file bytes. + override_config: Dict[str, str] + The config overrides dict sent by the user (using `flwr run`). Returns -------