diff --git a/datasets/README.md b/datasets/README.md
index 1d8014d57ea3..50fc67376ae4 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -42,11 +42,12 @@ Create **custom partitioning schemes** or choose from the **implemented [partiti
* IID partitioning `IidPartitioner(num_partitions)`
* Dirichlet partitioning `DirichletPartitioner(num_partitions, partition_by, alpha)`
* InnerDirichlet partitioning `InnerDirichletPartitioner(partition_sizes, partition_by, alpha)`
-* Natural ID partitioner `NaturalIdPartitioner(partition_by)`
-* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner`
-* Linear partitioner `LinearPartitioner(num_partitions)`
-* Square partitioner `SquarePartitioner(num_partitions)`
-* Exponential partitioner `ExponentialPartitioner(num_partitions)`
+* Pathological partitioning `PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)`
+* Natural ID partitioning `NaturalIdPartitioner(partition_by)`
+* Size-based partitioning (the abstract base class for partitioners dictating the division based on the number of samples) `SizePartitioner`
+* Linear partitioning `LinearPartitioner(num_partitions)`
+* Square partitioning `SquarePartitioner(num_partitions)`
+* Exponential partitioning `ExponentialPartitioner(num_partitions)`
* more to come in the future releases (contributions are welcome).
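+
+A minimal sketch of how the new pathological partitioning can be used (the `"mnist"` dataset name and its `"label"` column are used purely for illustration):
+
+```python
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import PathologicalPartitioner
+
+# Each partition receives samples from exactly 2 classes
+partitioner = PathologicalPartitioner(
+    num_partitions=10, partition_by="label", num_classes_per_partition=2
+)
+fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+partition = fds.load_partition(0)
+```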
diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst
index bdcea7650bbc..fcc7920711bf 100644
--- a/datasets/doc/source/index.rst
+++ b/datasets/doc/source/index.rst
@@ -94,6 +94,7 @@ Here are a few of the ``Partitioner`` s that are available: (for a full list see
* IID partitioning ``IidPartitioner(num_partitions)``
* Dirichlet partitioning ``DirichletPartitioner(num_partitions, partition_by, alpha)``
* InnerDirichlet partitioning ``InnerDirichletPartitioner(partition_sizes, partition_by, alpha)``
+* Pathological partitioning ``PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)``
* Natural ID partitioner ``NaturalIdPartitioner(partition_by)``
* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) ``SizePartitioner``
* Linear partitioner ``LinearPartitioner(num_partitions)``
diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py
index 1fc00ed90323..0c75dbce387a 100644
--- a/datasets/flwr_datasets/partitioner/__init__.py
+++ b/datasets/flwr_datasets/partitioner/__init__.py
@@ -22,6 +22,7 @@
from .linear_partitioner import LinearPartitioner
from .natural_id_partitioner import NaturalIdPartitioner
from .partitioner import Partitioner
+from .pathological_partitioner import PathologicalPartitioner
from .shard_partitioner import ShardPartitioner
from .size_partitioner import SizePartitioner
from .square_partitioner import SquarePartitioner
@@ -34,6 +35,7 @@
"LinearPartitioner",
"NaturalIdPartitioner",
"Partitioner",
+ "PathologicalPartitioner",
"ShardPartitioner",
"SizePartitioner",
"SquarePartitioner",
diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py
new file mode 100644
index 000000000000..1ee60d283044
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py
@@ -0,0 +1,305 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pathological partitioner class that works with Hugging Face Datasets."""
+
+
+import warnings
+from typing import Any, Dict, List, Literal, Optional
+
+import numpy as np
+
+import datasets
+from flwr_datasets.common.typing import NDArray
+from flwr_datasets.partitioner.partitioner import Partitioner
+
+
+# pylint: disable=too-many-arguments, too-many-instance-attributes
+class PathologicalPartitioner(Partitioner):
+ """Partition dataset such that each partition has a chosen number of classes.
+
+ Implementation based on Federated Learning on Non-IID Data Silos: An Experimental
+ Study https://arxiv.org/pdf/2102.02079.
+
+ The algorithm first determines which classes will be assigned to which partitions.
+ For each partition, `num_classes_per_partition` classes are sampled according to
+ `class_assignment_mode`. Given the classes required by each partition, it is
+ determined into how many parts the samples corresponding to each label should be
+ divided. This division is performed for each class.
+
+ Parameters
+ ----------
+ num_partitions : int
+ The total number of partitions that the data will be divided into.
+ partition_by : str
+ Column name of the labels (targets) based on which partitioning works.
+ num_classes_per_partition: int
+ The (exact) number of unique classes that each partition will have.
+ class_assignment_mode: Literal["random", "deterministic", "first-deterministic"]
+ The way classes are assigned to partitions. The default is "random".
+ The possible values are:
+
+ - "random": Randomly assign classes to the partitions. For each partition choose
+ the `num_classes_per_partition` classes without replacement.
+ - "first-deterministic": Assign the first class for each partition in a
+ deterministic way (class id is the partition_id % num_unique_classes).
+ The rest of the classes are assigned randomly. In case the number of
+ partitions is smaller than the number of unique classes, not all classes will
+ be used in the first iteration, otherwise all the classes will be used (such
+ it will be present in at least one partition).
+ - "deterministic": Assign all the classes to the partitions in a deterministic
+ way. Classes are assigned based on the formula: partion_id has classes
+ identified by the index: (partition_id + i) % num_unique_classes
+ where i in {0, ..., num_classes_per_partition}. So, partition 0 will have
+ classes 0, 1, 2, ..., `num_classes_per_partition`-1, partition 1 will have
+ classes 1, 2, 3, ...,`num_classes_per_partition`, ....
+
+ The list representing the unique lables is sorted in ascending order. In case
+ of numbers starting from zero the class id corresponds to the number itself.
+ `class_assignment_mode="first-deterministic"` was used in the orginal paper,
+ here we provide the option to use the other modes as well.
+ shuffle: bool
+ Whether to randomize the order of samples. Shuffling is applied after the
+ samples are assigned to partitions.
+ seed: int
+ Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+
+ Examples
+ --------
+ To mimic the behavior described in the original paper, use the setup below
+ (with `class_assignment_mode="first-deterministic"`):
+
+ >>> from flwr_datasets.partitioner import PathologicalPartitioner
+ >>> from flwr_datasets import FederatedDataset
+ >>>
+ >>> partitioner = PathologicalPartitioner(
+ >>> num_partitions=10,
+ >>> partition_by="label",
+ >>> num_classes_per_partition=2,
+ >>> class_assignment_mode="first-deterministic"
+ >>> )
+ >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+ >>> partition = fds.load_partition(0)
+ """
+
+ def __init__(
+ self,
+ num_partitions: int,
+ partition_by: str,
+ num_classes_per_partition: int,
+ class_assignment_mode: Literal[
+ "random", "deterministic", "first-deterministic"
+ ] = "random",
+ shuffle: bool = True,
+ seed: Optional[int] = 42,
+ ) -> None:
+ super().__init__()
+ self._num_partitions = num_partitions
+ self._partition_by = partition_by
+ self._num_classes_per_partition = num_classes_per_partition
+ self._class_assignment_mode = class_assignment_mode
+ self._shuffle = shuffle
+ self._seed = seed
+ self._rng = np.random.default_rng(seed=self._seed)
+
+ # Utility attributes
+ self._partition_id_to_indices: Dict[int, List[int]] = {}
+ self._partition_id_to_unique_labels: Dict[int, List[Any]] = {
+ pid: [] for pid in range(self._num_partitions)
+ }
+ self._unique_labels: List[Any] = []
+ # Count the number of partitions in which each label is used
+ self._unique_label_to_times_used_counter: Dict[Any, int] = {}
+ self._partition_id_to_indices_determined = False
+
+ def load_partition(self, partition_id: int) -> datasets.Dataset:
+ """Load a partition based on the partition index.
+
+ Parameters
+ ----------
+ partition_id : int
+ The index that corresponds to the requested partition.
+
+ Returns
+ -------
+ dataset_partition : Dataset
+ Single partition of a dataset.
+ """
+ # The partitioning is done lazily - only when the first partition is
+ # requested. Only the first call creates the index assignments for all
+ # the partitions.
+ self._check_num_partitions_correctness_if_needed()
+ self._determine_partition_id_to_indices_if_needed()
+ return self.dataset.select(self._partition_id_to_indices[partition_id])
+
+ @property
+ def num_partitions(self) -> int:
+ """Total number of partitions."""
+ self._check_num_partitions_correctness_if_needed()
+ self._determine_partition_id_to_indices_if_needed()
+ return self._num_partitions
+
+ def _determine_partition_id_to_indices_if_needed(self) -> None:
+ """Create an assignment of indices to the partition indices."""
+ if self._partition_id_to_indices_determined:
+ return
+ self._determine_partition_id_to_unique_labels()
+ assert self._unique_labels is not None
+ self._count_partitions_having_each_unique_label()
+
+ labels = np.asarray(self.dataset[self._partition_by])
+ self._check_correctness_of_unique_label_to_times_used_counter(labels)
+ for partition_id in range(self._num_partitions):
+ self._partition_id_to_indices[partition_id] = []
+
+ unused_labels = []
+ for unique_label in self._unique_labels:
+ if self._unique_label_to_times_used_counter[unique_label] == 0:
+ unused_labels.append(unique_label)
+ continue
+ # Get the indices in the original dataset where the y == unique_label
+ unique_label_to_indices = np.where(labels == unique_label)[0]
+
+ split_unique_labels_to_indices = np.array_split(
+ unique_label_to_indices,
+ self._unique_label_to_times_used_counter[unique_label],
+ )
+
+ split_index = 0
+ for partition_id in range(self._num_partitions):
+ if unique_label in self._partition_id_to_unique_labels[partition_id]:
+ self._partition_id_to_indices[partition_id].extend(
+ split_unique_labels_to_indices[split_index]
+ )
+ split_index += 1
+
+ if len(unused_labels) >= 1:
+ warnings.warn(
+ f"Classes: {unused_labels} will NOT be used due to the chosen "
+ f"configuration. If it is undesired behavior consider setting"
+ f" 'first_class_deterministic_assignment=True' which in case when"
+ f" the number of classes is smaller than the number of partitions will "
+ f"utilize all the classes for the created partitions.",
+ stacklevel=1,
+ )
+ if self._shuffle:
+ for indices in self._partition_id_to_indices.values():
+ # In place shuffling
+ self._rng.shuffle(indices)
+
+ self._partition_id_to_indices_determined = True
+
+ def _check_num_partitions_correctness_if_needed(self) -> None:
+ """Test num_partitions when the dataset is given (in load_partition)."""
+ if not self._partition_id_to_indices_determined:
+ if self._num_partitions > self.dataset.num_rows:
+ raise ValueError(
+ "The number of partitions needs to be smaller than the number of "
+ "samples in the dataset."
+ )
+
+ def _determine_partition_id_to_unique_labels(self) -> None:
+ """Determine the assignment of unique labels to the partitions."""
+ self._unique_labels = sorted(self.dataset.unique(self._partition_by))
+ num_unique_classes = len(self._unique_labels)
+
+ if self._num_classes_per_partition > num_unique_classes:
+ raise ValueError(
+ f"The specified `num_classes_per_partition`"
+ f"={self._num_classes_per_partition} is greater than the number "
+ f"of unique classes in the given dataset={num_unique_classes}. "
+ f"Reduce the `num_classes_per_partition` or make use different dataset "
+ f"to apply this partitioning."
+ )
+ if self._class_assignment_mode == "first-deterministic":
+ for partition_id in range(self._num_partitions):
+ label = partition_id % num_unique_classes
+ self._partition_id_to_unique_labels[partition_id].append(label)
+
+ while (
+ len(self._partition_id_to_unique_labels[partition_id])
+ < self._num_classes_per_partition
+ ):
+ label = self._rng.choice(self._unique_labels, size=1)[0]
+ if label not in self._partition_id_to_unique_labels[partition_id]:
+ self._partition_id_to_unique_labels[partition_id].append(label)
+ elif self._class_assignment_mode == "deterministic":
+ for partition_id in range(self._num_partitions):
+ labels = []
+ for i in range(self._num_classes_per_partition):
+ label = self._unique_labels[
+ (partition_id + i) % len(self._unique_labels)
+ ]
+ labels.append(label)
+ self._partition_id_to_unique_labels[partition_id] = labels
+ elif self._class_assignment_mode == "random":
+ for partition_id in range(self._num_partitions):
+ labels = self._rng.choice(
+ self._unique_labels,
+ size=self._num_classes_per_partition,
+ replace=False,
+ ).tolist()
+ self._partition_id_to_unique_labels[partition_id] = labels
+ else:
+ raise ValueError(
+ f"The supported class_assignment_mode are: 'random', 'deterministic', "
+ f"'first-deterministic'. You provided: {self._class_assignment_mode}."
+ )
+
+ def _count_partitions_having_each_unique_label(self) -> None:
+ """Count the number of partitions that have each unique label.
+
+ This computation is based on the assignment of labels to partition ids done in
+ the `_determine_partition_id_to_unique_labels` method.
+ Given:
+ * partition 0 has only labels 0 and 1 (it may contain many samples, but each
+ sample has label 0 or 1),
+ * partition 1 has only labels 1 and 2 (same note on sample counts as above),
+ * and there are only two partitions, then the following counter will be computed:
+ {
+ 0: 1,
+ 1: 2,
+ 2: 1
+ }
+ """
+ for unique_label in self._unique_labels:
+ self._unique_label_to_times_used_counter[unique_label] = 0
+ for unique_labels in self._partition_id_to_unique_labels.values():
+ for unique_label in unique_labels:
+ self._unique_label_to_times_used_counter[unique_label] += 1
+
+ def _check_correctness_of_unique_label_to_times_used_counter(
+ self, labels: NDArray
+ ) -> None:
+ """Check if partitioning is possible given the presence requirements.
+
+ The number of times a label is assigned to partitions must be smaller than or
+ equal to the number of samples with that label in the dataset.
+ """
+ for unique_label in self._unique_labels:
+ num_unique = np.sum(labels == unique_label)
+ if self._unique_label_to_times_used_counter[unique_label] > num_unique:
+ raise ValueError(
+ f"Label: {unique_label} is needed to be assigned to more "
+ f"partitions "
+ f"({self._unique_label_to_times_used_counter[unique_label]})"
+ f" than there are samples (corresponding to this label) in the "
+ f"dataset ({num_unique}). Please decrease the `num_partitions`, "
+ f"`num_classes_per_partition` to avoid this situation, "
+ f"or try `class_assigment_mode='deterministic'` to create a more "
+ f"even distribution of classes along the partitions. "
+ f"Alternatively use a different dataset if you can not adjust"
+ f" the any of these parameters."
+ )
diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py
new file mode 100644
index 000000000000..151b7e14659c
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py
@@ -0,0 +1,262 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for PathologicalPartitioner."""
+
+
+import unittest
+from typing import Dict
+
+import numpy as np
+from parameterized import parameterized
+
+import datasets
+from datasets import Dataset
+from flwr_datasets.partitioner.pathological_partitioner import PathologicalPartitioner
+
+
+def _dummy_dataset_setup(
+ num_samples: int, partition_by: str, num_unique_classes: int
+) -> Dataset:
+ """Create a dummy dataset for testing."""
+ data = {
+ partition_by: np.tile(
+ np.arange(num_unique_classes), num_samples // num_unique_classes + 1
+ )[:num_samples],
+ "features": np.random.randn(num_samples),
+ }
+ return Dataset.from_dict(data)
+
+
+class TestPathologicalPartitioner(unittest.TestCase):
+ """Unit tests for PathologicalPartitioner."""
+
+ @parameterized.expand( # type: ignore
+ [
+ # num_partition, num_classes_per_partition, num_samples, total_classes
+ (3, 1, 60, 3), # Single class per partition scenario
+ (5, 2, 100, 5),
+ (5, 2, 100, 10),
+ (4, 3, 120, 6),
+ ]
+ )
+ def test_correct_num_classes_when_partitioned(
+ self,
+ num_partitions: int,
+ num_classes_per_partition: int,
+ num_samples: int,
+ num_unique_classes: int,
+ ) -> None:
+ """Test correct number of unique classes."""
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes)
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ )
+ partitioner.dataset = dataset
+ partitions: Dict[int, Dataset] = {
+ pid: partitioner.load_partition(pid) for pid in range(num_partitions)
+ }
+ unique_classes_per_partition = {
+ pid: np.unique(partition["labels"]) for pid, partition in partitions.items()
+ }
+
+ for unique_classes in unique_classes_per_partition.values():
+ self.assertEqual(num_classes_per_partition, len(unique_classes))
+
+ def test_first_class_deterministic_assignment(self) -> None:
+ """Test deterministic assignment of first classes to partitions.
+
+ Test if all the classes are used (which has to be the case, given
+ num_partitions >= the number of unique classes).
+ """
+ dataset = _dummy_dataset_setup(100, "labels", 10)
+ partitioner = PathologicalPartitioner(
+ num_partitions=10,
+ partition_by="labels",
+ num_classes_per_partition=2,
+ class_assignment_mode="first-deterministic",
+ )
+ partitioner.dataset = dataset
+ partitioner.load_partition(0)
+ expected_classes = set(range(10))
+ actual_classes = set()
+ for pid in range(10):
+ partition = partitioner.load_partition(pid)
+ actual_classes.update(np.unique(partition["labels"]))
+ self.assertEqual(expected_classes, actual_classes)
+
+ @parameterized.expand(
+ [ # type: ignore
+ # num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ (4, 2, 80, 8),
+ (10, 2, 100, 10),
+ ]
+ )
+ def test_deterministic_class_assignment(
+ self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ ):
+ """Test deterministic assignment of classes to partitions."""
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes)
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ class_assignment_mode="deterministic",
+ )
+ partitioner.dataset = dataset
+ partitions = {
+ pid: partitioner.load_partition(pid) for pid in range(num_partitions)
+ }
+
+ # Verify each partition has the expected classes, order does not matter
+ for pid, partition in partitions.items():
+ expected_labels = sorted(
+ [
+ (pid + i) % num_unique_classes
+ for i in range(num_classes_per_partition)
+ ]
+ )
+ actual_labels = sorted(np.unique(partition["labels"]))
+ self.assertTrue(
+ np.array_equal(expected_labels, actual_labels),
+ f"Partition {pid} does not have the expected labels: "
+ f"{expected_labels} but instead {actual_labels}.",
+ )
+
+ @parameterized.expand(
+ [ # type: ignore
+ # num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ (10, 3, 20, 3),
+ ]
+ )
+ def test_too_many_partitions_for_a_class(
+ self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ ) -> None:
+ """Test too many partitions for the number of samples in a class."""
+ dataset_1 = _dummy_dataset_setup(
+ num_samples // 2, "labels", num_unique_classes - 1
+ )
+ # Create a skewed part of the dataset for the last label
+ data = {
+ "labels": np.array([num_unique_classes - 1] * (num_samples // 2)),
+ "features": np.random.randn(num_samples // 2),
+ }
+ dataset_2 = Dataset.from_dict(data)
+ dataset = datasets.concatenate_datasets([dataset_1, dataset_2])
+
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ class_assignment_mode="random",
+ )
+ partitioner.dataset = dataset
+
+ with self.assertRaises(ValueError) as context:
+ _ = partitioner.load_partition(0)
+ self.assertEqual(
+ str(context.exception),
+ "Label: 0 is needed to be assigned to more partitions (10) than there are "
+ "samples (corresponding to this label) in the dataset (5). "
+ "Please decrease the `num_partitions`, `num_classes_per_partition` to "
+ "avoid this situation, or try `class_assigment_mode='deterministic'` to "
+ "create a more even distribution of classes along the partitions. "
+ "Alternatively use a different dataset if you can not adjust the any of "
+ "these parameters.",
+ )
+
+ @parameterized.expand( # type: ignore
+ [
+ # num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ (10, 11, 100, 10), # 11 > 10
+ (5, 11, 100, 10), # 11 > 10
+ (10, 20, 100, 5), # 20 > 5
+ ]
+ )
+ def test_more_classes_per_partition_than_num_unique_classes_in_dataset_raises(
+ self,
+ num_partitions: int,
+ num_classes_per_partition: int,
+ num_samples: int,
+ num_unique_classes: int,
+ ) -> None:
+ """Test more num_classes_per_partition > num_unique_classes in the dataset."""
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes)
+ with self.assertRaises(ValueError) as context:
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ )
+ partitioner.dataset = dataset
+ partitioner.load_partition(0)
+ self.assertEqual(
+ str(context.exception),
+ "The specified "
+ f"`num_classes_per_partition`={num_classes_per_partition} is "
+ f"greater than the number of unique classes in the given "
+ f"dataset={len(dataset.unique('labels'))}. Reduce the "
+ f"`num_classes_per_partition` or make use different dataset "
+ f"to apply this partitioning.",
+ )
+
+ @parameterized.expand( # type: ignore
+ [
+ # num_classes_per_partition should be irrelevant since the exception should
+ # be raised at the very beginning
+ # num_partitions, num_classes_per_partition, num_samples
+ (10, 2, 5),
+ (10, 10, 5),
+ (100, 10, 99),
+ ]
+ )
+ def test_more_partitions_than_samples_raises(
+ self, num_partitions: int, num_classes_per_partition: int, num_samples: int
+ ) -> None:
+ """Test if generation of more partitions that there are samples raises."""
+ # The number of unique classes in the dataset should be irrelevant since the
+ # exception should be raised at the very beginning
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes=5)
+ with self.assertRaises(ValueError) as context:
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ )
+ partitioner.dataset = dataset
+ partitioner.load_partition(0)
+ self.assertEqual(
+ str(context.exception),
+ "The number of partitions needs to be smaller than the number of "
+ "samples in the dataset.",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py
index 20e0fbb9b229..998bdb44e4c0 100644
--- a/src/py/flwr/client/app.py
+++ b/src/py/flwr/client/app.py
@@ -18,7 +18,7 @@
import sys
import time
from dataclasses import dataclass
-from logging import DEBUG, ERROR, INFO, WARN
+from logging import ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
@@ -296,7 +296,7 @@ def _on_backoff(retry_state: RetryState) -> None:
log(WARN, "Connection attempt failed, retrying...")
else:
log(
- DEBUG,
+ WARN,
"Connection attempt failed, retrying in %.2f seconds",
retry_state.actual_wait,
)
@@ -320,7 +320,9 @@ def _on_backoff(retry_state: RetryState) -> None:
on_backoff=_on_backoff,
)
- node_state = NodeState(node_id=-1, node_config=node_config, partition_id=-1)
+ # NodeState gets initialized when the first connection is established
+ node_state: Optional[NodeState] = None
+
runs: Dict[int, Run] = {}
while not app_state_tracker.interrupt:
@@ -335,13 +337,33 @@ def _on_backoff(retry_state: RetryState) -> None:
) as conn:
receive, send, create_node, delete_node, get_run = conn
- # Register node
- if create_node is not None:
- node_id = ( # pylint: disable=assignment-from-none
- create_node()
- ) # pylint: disable=not-callable
- if transport in ["grpc-rere", None]:
- node_state.node_id = node_id # type: ignore
+ # Register node when connecting the first time
+ if node_state is None:
+ if create_node is None:
+ if transport not in ["grpc-bidi", None]:
+ raise NotImplementedError(
+ "All transports except `grpc-bidi` require "
+ "an implementation for `create_node()`.'"
+ )
+ # gRPC-bidi doesn't have the concept of node_id,
+ # so we set it to -1
+ node_state = NodeState(
+ node_id=-1,
+ node_config={},
+ partition_id=None,
+ )
+ else:
+ # Call create_node fn to register node
+ node_id: Optional[int] = ( # pylint: disable=assignment-from-none
+ create_node()
+ ) # pylint: disable=not-callable
+ if node_id is None:
+ raise ValueError("Node registration failed")
+ node_state = NodeState(
+ node_id=node_id,
+ node_config=node_config,
+ partition_id=None,
+ )
app_state_tracker.register_signal_handler()
while not app_state_tracker.interrupt:
diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py
index d4071f3b1793..80a5cf0b4656 100644
--- a/src/py/flwr/client/grpc_adapter_client/connection.py
+++ b/src/py/flwr/client/grpc_adapter_client/connection.py
@@ -44,7 +44,7 @@ def grpc_adapter( # pylint: disable=R0913
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
- Optional[Callable[[], int]],
+ Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py
index c4051e6c5c16..e573df6854bc 100644
--- a/src/py/flwr/client/grpc_rere_client/connection.py
+++ b/src/py/flwr/client/grpc_rere_client/connection.py
@@ -79,7 +79,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
- Optional[Callable[[], int]],
+ Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
@@ -176,7 +176,7 @@ def ping() -> None:
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)
- def create_node() -> int:
+ def create_node() -> Optional[int]:
"""Set create_node."""
# Call FleetAPI
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py
index 4e667801f105..3e81969d898c 100644
--- a/src/py/flwr/client/rest_client/connection.py
+++ b/src/py/flwr/client/rest_client/connection.py
@@ -237,19 +237,20 @@ def ping() -> None:
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)
- def create_node() -> None:
+ def create_node() -> Optional[int]:
"""Set create_node."""
req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
# Send the request
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
if res is None:
- return
+ return None
# Remember the node and the ping-loop thread
nonlocal node, ping_thread
node = res.node
ping_thread = start_ping_loop(ping, ping_stop_event)
+ return node.node_id
def delete_node() -> None:
"""Set delete_node."""
diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py
index efe3f8f01664..e65300278c84 100644
--- a/src/py/flwr/common/context.py
+++ b/src/py/flwr/common/context.py
@@ -28,10 +28,10 @@ class Context:
Parameters
----------
node_id : int
- A UUID that identifies the node.
+ The ID that identifies the node.
node_config : Dict[str, str]
A config (key/value mapping) unique to the node and independent of the
- `run_config`. This config persist across runs this node participates in.
+ `run_config`. This config persists across all runs this node participates in.
state : RecordSet
Holds records added by the entity in a given run and that will stay local.
This means that the data it holds will never leave the system it's running from.
diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py
index b4d4b462bbcc..372ccb443a76 100644
--- a/src/py/flwr/superexec/app.py
+++ b/src/py/flwr/superexec/app.py
@@ -127,11 +127,11 @@ def _try_obtain_certificates(
return None
# Check if certificates are provided
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
- if not Path.is_file(args.ssl_ca_certfile):
+ if not Path(args.ssl_ca_certfile).is_file():
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
- if not Path.is_file(args.ssl_certfile):
+ if not Path(args.ssl_certfile).is_file():
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
- if not Path.is_file(args.ssl_keyfile):
+ if not Path(args.ssl_keyfile).is_file():
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
certificates = (
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate