From 356e3f4b65d8d6a47816c1efafdb48622109af86 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Fri, 12 Jul 2024 12:55:16 +0200 Subject: [PATCH 1/5] feat(datasets) Add pathological partitioner (#3623) Co-authored-by: jafermarq --- .../flwr_datasets/partitioner/__init__.py | 2 + .../partitioner/pathological_partitioner.py | 305 ++++++++++++++++++ .../pathological_partitioner_test.py | 262 +++++++++++++++ 3 files changed, 569 insertions(+) create mode 100644 datasets/flwr_datasets/partitioner/pathological_partitioner.py create mode 100644 datasets/flwr_datasets/partitioner/pathological_partitioner_test.py diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 1fc00ed90323..0c75dbce387a 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -22,6 +22,7 @@ from .linear_partitioner import LinearPartitioner from .natural_id_partitioner import NaturalIdPartitioner from .partitioner import Partitioner +from .pathological_partitioner import PathologicalPartitioner from .shard_partitioner import ShardPartitioner from .size_partitioner import SizePartitioner from .square_partitioner import SquarePartitioner @@ -34,6 +35,7 @@ "LinearPartitioner", "NaturalIdPartitioner", "Partitioner", + "PathologicalPartitioner", "ShardPartitioner", "SizePartitioner", "SquarePartitioner", diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py new file mode 100644 index 000000000000..1ee60d283044 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py @@ -0,0 +1,305 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pathological partitioner class that works with Hugging Face Datasets.""" + + +import warnings +from typing import Any, Dict, List, Literal, Optional + +import numpy as np + +import datasets +from flwr_datasets.common.typing import NDArray +from flwr_datasets.partitioner.partitioner import Partitioner + + +# pylint: disable=too-many-arguments, too-many-instance-attributes +class PathologicalPartitioner(Partitioner): + """Partition dataset such that each partition has a chosen number of classes. + + Implementation based on Federated Learning on Non-IID Data Silos: An Experimental + Study https://arxiv.org/pdf/2102.02079. + + The algorithm first determines which classes will be assigned to which partitions. + For each partition `num_classes_per_partition` classes are sampled in a way chosen in + `class_assignment_mode`. Given the information about the required classes for each + partition, it is determined into how many parts the samples corresponding to this + label should be divided. Such division is performed for each class. + + Parameters + ---------- + num_partitions : int + The total number of partitions that the data will be divided into. + partition_by : str + Column name of the labels (targets) based on which partitioning works. + num_classes_per_partition: int + The (exact) number of unique classes that each partition will have. + class_assignment_mode: Literal["random", "deterministic", "first-deterministic"] + The way the classes are assigned to the partitions.
The default is "random". + The possible values are: + + - "random": Randomly assign classes to the partitions. For each partition choose + the `num_classes_per_partition` classes without replacement. + - "first-deterministic": Assign the first class for each partition in a + deterministic way (class id is the partition_id % num_unique_classes). + The rest of the classes are assigned randomly. In case the number of + partitions is smaller than the number of unique classes, not all classes will + be used in the first iteration, otherwise all the classes will be used (such + that each class is present in at least one partition). + - "deterministic": Assign all the classes to the partitions in a deterministic + way. Classes are assigned based on the formula: partition_id has classes + identified by the index: (partition_id + i) % num_unique_classes + where i in {0, ..., num_classes_per_partition - 1}. So, partition 0 will have + classes 0, 1, 2, ..., `num_classes_per_partition`-1, partition 1 will have + classes 1, 2, 3, ..., `num_classes_per_partition`, .... + + The list representing the unique labels is sorted in ascending order. In case + of numbers starting from zero the class id corresponds to the number itself. + `class_assignment_mode="first-deterministic"` was used in the original paper, + here we provide the option to use the other modes as well. + shuffle: bool + Whether to randomize the order of samples. Shuffling is applied after the + samples assignment to partitions. + seed: int + Seed for the random class assignment and (if `shuffle` is True) for shuffling.
+ + Examples + -------- + In order to mimic the original behavior of the paper follow the setup below + (the `class_assignment_mode="first-deterministic"`): + + >>> from flwr_datasets.partitioner import PathologicalPartitioner + >>> from flwr_datasets import FederatedDataset + >>> + >>> partitioner = PathologicalPartitioner( + >>> num_partitions=10, + >>> partition_by="label", + >>> num_classes_per_partition=2, + >>> class_assignment_mode="first-deterministic" + >>> ) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + """ + + def __init__( + self, + num_partitions: int, + partition_by: str, + num_classes_per_partition: int, + class_assignment_mode: Literal[ + "random", "deterministic", "first-deterministic" + ] = "random", + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + self._num_partitions = num_partitions + self._partition_by = partition_by + self._num_classes_per_partition = num_classes_per_partition + self._class_assignment_mode = class_assignment_mode + self._shuffle = shuffle + self._seed = seed + self._rng = np.random.default_rng(seed=self._seed) + + # Utility attributes + self._partition_id_to_indices: Dict[int, List[int]] = {} + self._partition_id_to_unique_labels: Dict[int, List[Any]] = { + pid: [] for pid in range(self._num_partitions) + } + self._unique_labels: List[Any] = [] + # Count in how many partitions the label is used + self._unique_label_to_times_used_counter: Dict[Any, int] = {} + self._partition_id_to_indices_determined = False + + def load_partition(self, partition_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + partition_id : int + The index that corresponds to the requested partition. + + Returns + ------- + dataset_partition : Dataset + Single partition of a dataset. + """ + # The partitioning is done lazily - only when the first partition is + # requested. 
Only the first call creates the indices assignments for all the + # partition indices. + self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self.dataset.select(self._partition_id_to_indices[partition_id]) + + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self._num_partitions + + def _determine_partition_id_to_indices_if_needed(self) -> None: + """Create an assignment of indices to the partition indices.""" + if self._partition_id_to_indices_determined: + return + self._determine_partition_id_to_unique_labels() + assert self._unique_labels is not None + self._count_partitions_having_each_unique_label() + + labels = np.asarray(self.dataset[self._partition_by]) + self._check_correctness_of_unique_label_to_times_used_counter(labels) + for partition_id in range(self._num_partitions): + self._partition_id_to_indices[partition_id] = [] + + unused_labels = [] + for unique_label in self._unique_labels: + if self._unique_label_to_times_used_counter[unique_label] == 0: + unused_labels.append(unique_label) + continue + # Get the indices in the original dataset where the y == unique_label + unique_label_to_indices = np.where(labels == unique_label)[0] + + split_unique_labels_to_indices = np.array_split( + unique_label_to_indices, + self._unique_label_to_times_used_counter[unique_label], + ) + + split_index = 0 + for partition_id in range(self._num_partitions): + if unique_label in self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_indices[partition_id].extend( + split_unique_labels_to_indices[split_index] + ) + split_index += 1 + + if len(unused_labels) >= 1: + warnings.warn( + f"Classes: {unused_labels} will NOT be used due to the chosen " + f"configuration. 
If it is undesired behavior consider setting" + f" 'first_class_deterministic_assignment=True' which in case when" + f" the number of classes is smaller than the number of partitions will " + f"utilize all the classes for the created partitions.", + stacklevel=1, + ) + if self._shuffle: + for indices in self._partition_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + + self._partition_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._partition_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _determine_partition_id_to_unique_labels(self) -> None: + """Determine the assignment of unique labels to the partitions.""" + self._unique_labels = sorted(self.dataset.unique(self._partition_by)) + num_unique_classes = len(self._unique_labels) + + if self._num_classes_per_partition > num_unique_classes: + raise ValueError( + f"The specified `num_classes_per_partition`" + f"={self._num_classes_per_partition} is greater than the number " + f"of unique classes in the given dataset={num_unique_classes}. " + f"Reduce the `num_classes_per_partition` or make use different dataset " + f"to apply this partitioning." 
+ ) + if self._class_assignment_mode == "first-deterministic": + # if self._first_class_deterministic_assignment: + for partition_id in range(self._num_partitions): + label = partition_id % num_unique_classes + self._partition_id_to_unique_labels[partition_id].append(label) + + while ( + len(self._partition_id_to_unique_labels[partition_id]) + < self._num_classes_per_partition + ): + label = self._rng.choice(self._unique_labels, size=1)[0] + if label not in self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_unique_labels[partition_id].append(label) + elif self._class_assignment_mode == "deterministic": + for partition_id in range(self._num_partitions): + labels = [] + for i in range(self._num_classes_per_partition): + label = self._unique_labels[ + (partition_id + i) % len(self._unique_labels) + ] + labels.append(label) + self._partition_id_to_unique_labels[partition_id] = labels + elif self._class_assignment_mode == "random": + for partition_id in range(self._num_partitions): + labels = self._rng.choice( + self._unique_labels, + size=self._num_classes_per_partition, + replace=False, + ).tolist() + self._partition_id_to_unique_labels[partition_id] = labels + else: + raise ValueError( + f"The supported class_assignment_mode are: 'random', 'deterministic', " + f"'first-deterministic'. You provided: {self._class_assignment_mode}." + ) + + def _count_partitions_having_each_unique_label(self) -> None: + """Count the number of partitions that have each unique label. + + This computation is based on the assigment of the label to the partition_id in + the `_determine_partition_id_to_unique_labels` method. 
+ Given: + * partition 0 has only labels: 0,1 (not necessarily just two samples it can have + many samples but either from 0 or 1) + * partition 1 has only labels: 1, 2 (same count note as above) + * and there are only two partitions then the following will be computed: + { + 0: 1, + 1: 2, + 2: 1 + } + """ + for unique_label in self._unique_labels: + self._unique_label_to_times_used_counter[unique_label] = 0 + for unique_labels in self._partition_id_to_unique_labels.values(): + for unique_label in unique_labels: + self._unique_label_to_times_used_counter[unique_label] += 1 + + def _check_correctness_of_unique_label_to_times_used_counter( + self, labels: NDArray + ) -> None: + """Check if partitioning is possible given the presence requirements. + + The number of times the label can be used must be smaller or equal to the number + of times that the label is present in the dataset. + """ + for unique_label in self._unique_labels: + num_unique = np.sum(labels == unique_label) + if self._unique_label_to_times_used_counter[unique_label] > num_unique: + raise ValueError( + f"Label: {unique_label} is needed to be assigned to more " + f"partitions " + f"({self._unique_label_to_times_used_counter[unique_label]})" + f" than there are samples (corresponding to this label) in the " + f"dataset ({num_unique}). Please decrease the `num_partitions`, " + f"`num_classes_per_partition` to avoid this situation, " + f"or try `class_assigment_mode='deterministic'` to create a more " + f"even distribution of classes along the partitions. " + f"Alternatively use a different dataset if you can not adjust" + f" the any of these parameters." + ) diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py new file mode 100644 index 000000000000..151b7e14659c --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py @@ -0,0 +1,262 @@ +# Copyright 2024 Flower Labs GmbH. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for PathologicalPartitioner.""" + + +import unittest +from typing import Dict + +import numpy as np +from parameterized import parameterized + +import datasets +from datasets import Dataset +from flwr_datasets.partitioner.pathological_partitioner import PathologicalPartitioner + + +def _dummy_dataset_setup( + num_samples: int, partition_by: str, num_unique_classes: int +) -> Dataset: + """Create a dummy dataset for testing.""" + data = { + partition_by: np.tile( + np.arange(num_unique_classes), num_samples // num_unique_classes + 1 + )[:num_samples], + "features": np.random.randn(num_samples), + } + return Dataset.from_dict(data) + + +def _dummy_heterogeneous_dataset_setup( + num_samples: int, partition_by: str, num_unique_classes: int +) -> Dataset: + """Create a dummy dataset for testing.""" + data = { + partition_by: np.tile( + np.arange(num_unique_classes), num_samples // num_unique_classes + 1 + )[:num_samples], + "features": np.random.randn(num_samples), + } + return Dataset.from_dict(data) + + +class TestClassConstrainedPartitioner(unittest.TestCase): + """Unit tests for PathologicalPartitioner.""" + + @parameterized.expand( # type: ignore + [ + # num_partition, num_classes_per_partition, num_samples, total_classes + (3, 1, 60, 3), # Single class per partition scenario + (5, 2, 100, 5), + (5, 2, 
100, 10), + (4, 3, 120, 6), + ] + ) + def test_correct_num_classes_when_partitioned( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test correct number of unique classes.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitions: Dict[int, Dataset] = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + unique_classes_per_partition = { + pid: np.unique(partition["labels"]) for pid, partition in partitions.items() + } + + for unique_classes in unique_classes_per_partition.values(): + self.assertEqual(num_classes_per_partition, len(unique_classes)) + + def test_first_class_deterministic_assignment(self) -> None: + """Test deterministic assignment of first classes to partitions. + + Test if all the classes are used (which has to be the case, given num_partitions + >= than the number of unique classes). 
+ """ + dataset = _dummy_dataset_setup(100, "labels", 10) + partitioner = PathologicalPartitioner( + num_partitions=10, + partition_by="labels", + num_classes_per_partition=2, + class_assignment_mode="first-deterministic", + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + expected_classes = set(range(10)) + actual_classes = set() + for pid in range(10): + partition = partitioner.load_partition(pid) + actual_classes.update(np.unique(partition["labels"])) + self.assertEqual(expected_classes, actual_classes) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (4, 2, 80, 8), + (10, 2, 100, 10), + ] + ) + def test_deterministic_class_assignment( + self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes + ): + """Test deterministic assignment of classes to partitions.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="deterministic", + ) + partitioner.dataset = dataset + partitions = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + + # Verify each partition has the expected classes, order does not matter + for pid, partition in partitions.items(): + expected_labels = sorted( + [ + (pid + i) % num_unique_classes + for i in range(num_classes_per_partition) + ] + ) + actual_labels = sorted(np.unique(partition["labels"])) + self.assertTrue( + np.array_equal(expected_labels, actual_labels), + f"Partition {pid} does not have the expected labels: " + f"{expected_labels} but instead {actual_labels}.", + ) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 3, 20, 3), + ] + ) + def test_too_many_partitions_for_a_class( + self, num_partitions, 
num_classes_per_partition, num_samples, num_unique_classes + ) -> None: + """Test too many partitions for the number of samples in a class.""" + dataset_1 = _dummy_dataset_setup( + num_samples // 2, "labels", num_unique_classes - 1 + ) + # Create a skewed part of the dataset for the last label + data = { + "labels": np.array([num_unique_classes - 1] * (num_samples // 2)), + "features": np.random.randn(num_samples // 2), + } + dataset_2 = Dataset.from_dict(data) + dataset = datasets.concatenate_datasets([dataset_1, dataset_2]) + + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="random", + ) + partitioner.dataset = dataset + + with self.assertRaises(ValueError) as context: + _ = partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "Label: 0 is needed to be assigned to more partitions (10) than there are " + "samples (corresponding to this label) in the dataset (5). " + "Please decrease the `num_partitions`, `num_classes_per_partition` to " + "avoid this situation, or try `class_assigment_mode='deterministic'` to " + "create a more even distribution of classes along the partitions. 
" + "Alternatively use a different dataset if you can not adjust the any of " + "these parameters.", + ) + + @parameterized.expand( # type: ignore + [ + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 11, 100, 10), # 11 > 10 + (5, 11, 100, 10), # 11 > 10 + (10, 20, 100, 5), # 20 > 5 + ] + ) + def test_more_classes_per_partition_than_num_unique_classes_in_dataset_raises( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test more num_classes_per_partition > num_unique_classes in the dataset.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The specified " + f"`num_classes_per_partition`={num_classes_per_partition} is " + f"greater than the number of unique classes in the given " + f"dataset={len(dataset.unique('labels'))}. 
Reduce the " + f"`num_classes_per_partition` or make use different dataset " + f"to apply this partitioning.", + ) + + @parameterized.expand( # type: ignore + [ + # num_classes_per_partition should be irrelevant since the exception should + # be raised at the very beginning + # num_partitions, num_classes_per_partition, num_samples + (10, 2, 5), + (10, 10, 5), + (100, 10, 99), + ] + ) + def test_more_partitions_than_samples_raises( + self, num_partitions: int, num_classes_per_partition: int, num_samples: int + ) -> None: + """Test if generation of more partitions that there are samples raises.""" + # The number of unique classes in the dataset should be irrelevant since the + # exception should be raised at the very beginning + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes=5) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The number of partitions needs to be smaller than the number of " + "samples in the dataset.", + ) + + +if __name__ == "__main__": + unittest.main() From 237988943ab8f0e80fbf467a665f48d4dc92d45e Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Fri, 12 Jul 2024 13:09:59 +0200 Subject: [PATCH 2/5] docs(datasets:skip) Update the partitioner list in README and index (#3785) --- datasets/README.md | 11 ++++++----- datasets/doc/source/index.rst | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 1d8014d57ea3..50fc67376ae4 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -42,11 +42,12 @@ Create **custom partitioning schemes** or choose from the **implemented [partiti * IID partitioning `IidPartitioner(num_partitions)` * Dirichlet 
partitioning `DirichletPartitioner(num_partitions, partition_by, alpha)` * InnerDirichlet partitioning `InnerDirichletPartitioner(partition_sizes, partition_by, alpha)` -* Natural ID partitioner `NaturalIdPartitioner(partition_by)` -* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner` -* Linear partitioner `LinearPartitioner(num_partitions)` -* Square partitioner `SquarePartitioner(num_partitions)` -* Exponential partitioner `ExponentialPartitioner(num_partitions)` +* Pathological partitioning `PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)` +* Natural ID partitioning `NaturalIdPartitioner(partition_by)` +* Size based partitioning (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner` +* Linear partitioning `LinearPartitioner(num_partitions)` +* Square partitioning `SquarePartitioner(num_partitions)` +* Exponential partitioning `ExponentialPartitioner(num_partitions)` * more to come in the future releases (contributions are welcome).

Comparison of partitioning schemes. diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst index bdcea7650bbc..fcc7920711bf 100644 --- a/datasets/doc/source/index.rst +++ b/datasets/doc/source/index.rst @@ -94,6 +94,7 @@ Here are a few of the ``Partitioner`` s that are available: (for a full list see * IID partitioning ``IidPartitioner(num_partitions)`` * Dirichlet partitioning ``DirichletPartitioner(num_partitions, partition_by, alpha)`` * InnerDirichlet partitioning ``InnerDirichletPartitioner(partition_sizes, partition_by, alpha)`` +* PathologicalPartitioner ``PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)`` * Natural ID partitioner ``NaturalIdPartitioner(partition_by)`` * Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) ``SizePartitioner`` * Linear partitioner ``LinearPartitioner(num_partitions)`` From 2464534578a982d44e631281672ede9f9b57e567 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Fri, 12 Jul 2024 15:06:02 +0200 Subject: [PATCH 3/5] refactor(framework) Log warnings on SuperNode retry attempts (#3789) --- src/py/flwr/client/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 851083d4abb7..bfe5147f78e1 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,7 +18,7 @@ import sys import time from dataclasses import dataclass -from logging import DEBUG, ERROR, INFO, WARN +from logging import ERROR, INFO, WARN from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union @@ -295,7 +295,7 @@ def _on_backoff(retry_state: RetryState) -> None: log(WARN, "Connection attempt failed, retrying...") else: log( - DEBUG, + WARN, "Connection attempt failed, retrying in %.2f seconds", retry_state.actual_wait, ) From 19fad01444d09ccd701e29ab6654fe8611889074 Mon Sep 17 00:00:00 
2001 From: Danny Date: Fri, 12 Jul 2024 15:52:03 +0200 Subject: [PATCH 4/5] fix(framework:skip) Update certificate check (#3786) Signed-off-by: Danny Heinrich Co-authored-by: Charles Beauville --- src/py/flwr/superexec/app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index b4d4b462bbcc..372ccb443a76 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -127,11 +127,11 @@ def _try_obtain_certificates( return None # Check if certificates are provided if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile: - if not Path.is_file(args.ssl_ca_certfile): + if not Path(args.ssl_ca_certfile).is_file(): sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.") - if not Path.is_file(args.ssl_certfile): + if not Path(args.ssl_certfile).is_file(): sys.exit("Path argument `--ssl-certfile` does not point to a file.") - if not Path.is_file(args.ssl_keyfile): + if not Path(args.ssl_keyfile).is_file(): sys.exit("Path argument `--ssl-keyfile` does not point to a file.") certificates = ( Path(args.ssl_ca_certfile).read_bytes(), # CA certificate From 01ca846aeae8e298a6b663861c9f1b5155b1c14c Mon Sep 17 00:00:00 2001 From: Javier Date: Fri, 12 Jul 2024 17:29:38 +0200 Subject: [PATCH 5/5] feat(framework) Capture `node_id`/`node_config` in `Context` via `NodeState` (#3780) Co-authored-by: Daniel J. 
Beutel --- src/py/flwr/client/app.py | 40 ++++++++++++++++--- .../client/grpc_adapter_client/connection.py | 2 +- src/py/flwr/client/grpc_client/connection.py | 2 +- .../client/grpc_rere_client/connection.py | 5 ++- .../message_handler/message_handler_test.py | 4 +- .../secure_aggregation/secaggplus_mod_test.py | 7 +++- src/py/flwr/client/mod/utils_test.py | 4 +- src/py/flwr/client/node_state.py | 11 +++-- src/py/flwr/client/node_state_tests.py | 2 +- src/py/flwr/client/rest_client/connection.py | 7 ++-- src/py/flwr/common/context.py | 15 ++++++- src/py/flwr/server/compat/legacy_context.py | 2 +- src/py/flwr/server/run_serverapp.py | 4 +- src/py/flwr/server/server_app_test.py | 2 +- .../fleet/vce/backend/raybackend_test.py | 2 +- .../server/superlink/fleet/vce/vce_api.py | 4 +- .../ray_transport/ray_client_proxy.py | 4 +- .../ray_transport/ray_client_proxy_test.py | 8 +++- 18 files changed, 95 insertions(+), 30 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index bfe5147f78e1..fa17ba9a8481 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -319,7 +319,13 @@ def _on_backoff(retry_state: RetryState) -> None: on_backoff=_on_backoff, ) - node_state = NodeState(partition_id=partition_id) + # Empty dict (for now) + # This will be removed once users can pass node_config via flower-supernode + node_config: Dict[str, str] = {} + + # NodeState gets initialized when the first connection is established + node_state: Optional[NodeState] = None + runs: Dict[int, Run] = {} while not app_state_tracker.interrupt: @@ -334,9 +340,33 @@ def _on_backoff(retry_state: RetryState) -> None: ) as conn: receive, send, create_node, delete_node, get_run = conn - # Register node - if create_node is not None: - create_node() # pylint: disable=not-callable + # Register node when connecting the first time + if node_state is None: + if create_node is None: + if transport not in ["grpc-bidi", None]: + raise NotImplementedError( + "All 
transports except `grpc-bidi` require " + "an implementation for `create_node()`.'" + ) + # gRPC-bidi doesn't have the concept of node_id, + # so we set it to -1 + node_state = NodeState( + node_id=-1, + node_config={}, + partition_id=partition_id, + ) + else: + # Call create_node fn to register node + node_id: Optional[int] = ( # pylint: disable=assignment-from-none + create_node() + ) # pylint: disable=not-callable + if node_id is None: + raise ValueError("Node registration failed") + node_state = NodeState( + node_id=node_id, + node_config=node_config, + partition_id=partition_id, + ) app_state_tracker.register_signal_handler() while not app_state_tracker.interrupt: @@ -580,7 +610,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index 971b630e470b..80a5cf0b4656 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -44,7 +44,7 @@ def grpc_adapter( # pylint: disable=R0913 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 3e9f261c1ecf..a6417106d51b 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -72,7 +72,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], 
Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 8062ce28fcc7..e573df6854bc 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -79,7 +79,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] @@ -176,7 +176,7 @@ def ping() -> None: if not ping_stop_event.is_set(): ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" # Call FleetAPI create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) @@ -189,6 +189,7 @@ def create_node() -> None: nonlocal node, ping_thread node = cast(Node, create_node_response.node) ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 9ce4c9620c43..96de7ce0c2cb 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -145,7 +145,7 @@ def test_client_without_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet(), run_config={}), + context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet(), run_config={}), + 
context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 5e4c4411e1f7..2832576fb4fc 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -73,7 +73,12 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: def _make_ctxt() -> Context: cfg = ConfigsRecord(SecAggPlusState().to_dict()) - return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}), run_config={}) + return Context( + node_id=123, + node_config={}, + state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}), + run_config={}, + ) def _make_set_state_fn( diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index 7a1dd8988399..a5bbd0a0bb4d 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -104,7 +104,7 @@ def test_multiple_mods(self) -> None: state = RecordSet() state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) - context = Context(state=state, run_config={}) + context = Context(node_id=0, node_config={}, state=state, run_config={}) message = _get_dummy_flower_message() # Execute @@ -129,7 +129,7 @@ def test_filter(self) -> None: # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) message = _get_dummy_flower_message() def filter_mod( diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 2b090eba9720..d0a349b0cae0 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional +from typing 
import Dict, Optional from flwr.common import Context, RecordSet from flwr.common.config import get_fused_config @@ -35,8 +35,11 @@ class RunInfo: class NodeState: """State of a node where client nodes execute runs.""" - def __init__(self, partition_id: Optional[int]) -> None: - self._meta: Dict[str, Any] = {} # holds metadata about the node + def __init__( + self, node_id: int, node_config: Dict[str, str], partition_id: Optional[int] + ) -> None: + self.node_id = node_id + self.node_config = node_config self.run_infos: Dict[int, RunInfo] = {} self._partition_id = partition_id @@ -52,6 +55,8 @@ def register_context( self.run_infos[run_id] = RunInfo( initial_run_config=initial_run_config, context=Context( + node_id=self.node_id, + node_config=self.node_config, state=RecordSet(), run_config=initial_run_config.copy(), partition_id=self._partition_id, diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index effd64a3ae7a..8d7971fa5280 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None: expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} # NodeState - node_state = NodeState(partition_id=None) + node_state = NodeState(node_id=0, node_config={}, partition_id=None) for task in tasks: run_id = task.run_id diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 0efa5731ae51..3e81969d898c 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -90,7 +90,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] @@ -237,19 +237,20 @@ def ping() -> None: if not ping_stop_event.is_set(): 
ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) # Send the request res = _request(req, CreateNodeResponse, PATH_CREATE_NODE) if res is None: - return + return None # Remember the node and the ping-loop thread nonlocal node, ping_thread node = res.node ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 8120723ce9e9..e65300278c84 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -27,6 +27,11 @@ class Context: Parameters ---------- + node_id : int + The ID that identifies the node. + node_config : Dict[str, str] + A config (key/value mapping) unique to the node and independent of the + `run_config`. This config persists across all runs this node participates in. state : RecordSet Holds records added by the entity in a given run and that will stay local. This means that the data it holds will never leave the system it's running from. @@ -44,16 +49,22 @@ class Context: simulation or proto typing setups. 
""" + node_id: int + node_config: Dict[str, str] state: RecordSet - partition_id: Optional[int] run_config: Dict[str, str] + partition_id: Optional[int] - def __init__( + def __init__( # pylint: disable=too-many-arguments self, + node_id: int, + node_config: Dict[str, str], state: RecordSet, run_config: Dict[str, str], partition_id: Optional[int] = None, ) -> None: + self.node_id = node_id + self.node_config = node_config self.state = state self.run_config = run_config self.partition_id = partition_id diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py index 9e120c824103..ee09d79012dc 100644 --- a/src/py/flwr/server/compat/legacy_context.py +++ b/src/py/flwr/server/compat/legacy_context.py @@ -52,4 +52,4 @@ def __init__( self.strategy = strategy self.client_manager = client_manager self.history = History() - super().__init__(state, run_config={}) + super().__init__(node_id=0, node_config={}, state=state, run_config={}) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index b4697e99913f..4cc25feb7e0e 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -78,7 +78,9 @@ def _load() -> ServerApp: server_app = _load() # Initialize Context - context = Context(state=RecordSet(), run_config=server_app_run_config) + context = Context( + node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config + ) # Call ServerApp server_app(driver=driver, context=context) diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index 7de8774d4c81..b0672b3202ed 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -29,7 +29,7 @@ def test_server_app_custom_mode() -> None: # Prepare app = ServerApp() driver = MagicMock() - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) called = 
{"called": False} diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 287983003f8c..da4390194d05 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -120,7 +120,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: ) # Construct emtpy Context - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) # Expected output expected_output = pi * mult_factor diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 3c0b36e1ca3c..134fd34ed8f0 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -284,7 +284,9 @@ def start_vce( # Construct mapping of NodeStates node_states: Dict[int, NodeState] = {} for node_id, partition_id in nodes_mapping.items(): - node_states[node_id] = NodeState(partition_id=partition_id) + node_states[node_id] = NodeState( + node_id=node_id, node_config={}, partition_id=partition_id + ) # Load backend config log(DEBUG, "Supported backends: %s", list(supported_backends.keys())) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 31bc22c84bd5..f2684016048e 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -59,7 +59,9 @@ def _load_app() -> ClientApp: self.app_fn = _load_app self.actor_pool = actor_pool - self.proxy_state = NodeState(partition_id=self.partition_id) + self.proxy_state = NodeState( + node_id=node_id, node_config={}, partition_id=self.partition_id + ) def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: """Sumbit a message to 
the ActorPool.""" diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 83f6cfe05313..8831e5f475ea 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -218,7 +218,13 @@ def _load_app() -> ClientApp: _load_app, message, str(node_id), - Context(state=RecordSet(), run_config={}, partition_id=node_id), + Context( + node_id=0, + node_config={}, + state=RecordSet(), + run_config={}, + partition_id=node_id, + ), ), )