diff --git a/datasets/README.md b/datasets/README.md
index 1d8014d57ea3..50fc67376ae4 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -42,11 +42,12 @@ Create **custom partitioning schemes** or choose from the **implemented [partiti
* IID partitioning `IidPartitioner(num_partitions)`
* Dirichlet partitioning `DirichletPartitioner(num_partitions, partition_by, alpha)`
* InnerDirichlet partitioning `InnerDirichletPartitioner(partition_sizes, partition_by, alpha)`
-* Natural ID partitioner `NaturalIdPartitioner(partition_by)`
-* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner`
-* Linear partitioner `LinearPartitioner(num_partitions)`
-* Square partitioner `SquarePartitioner(num_partitions)`
-* Exponential partitioner `ExponentialPartitioner(num_partitions)`
+* Pathological partitioning `PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)`
+* Natural ID partitioning `NaturalIdPartitioner(partition_by)`
+* Size-based partitioning (the abstract base class for the partitioners dictating the division based on the number of samples) `SizePartitioner`
+* Linear partitioning `LinearPartitioner(num_partitions)`
+* Square partitioning `SquarePartitioner(num_partitions)`
+* Exponential partitioning `ExponentialPartitioner(num_partitions)`
* more to come in the future releases (contributions are welcome).
diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst
index bdcea7650bbc..fcc7920711bf 100644
--- a/datasets/doc/source/index.rst
+++ b/datasets/doc/source/index.rst
@@ -94,6 +94,7 @@ Here are a few of the ``Partitioner`` s that are available: (for a full list see
* IID partitioning ``IidPartitioner(num_partitions)``
* Dirichlet partitioning ``DirichletPartitioner(num_partitions, partition_by, alpha)``
* InnerDirichlet partitioning ``InnerDirichletPartitioner(partition_sizes, partition_by, alpha)``
+* Pathological partitioning ``PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)``
* Natural ID partitioner ``NaturalIdPartitioner(partition_by)``
* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) ``SizePartitioner``
* Linear partitioner ``LinearPartitioner(num_partitions)``
diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py
index 1fc00ed90323..0c75dbce387a 100644
--- a/datasets/flwr_datasets/partitioner/__init__.py
+++ b/datasets/flwr_datasets/partitioner/__init__.py
@@ -22,6 +22,7 @@
from .linear_partitioner import LinearPartitioner
from .natural_id_partitioner import NaturalIdPartitioner
from .partitioner import Partitioner
+from .pathological_partitioner import PathologicalPartitioner
from .shard_partitioner import ShardPartitioner
from .size_partitioner import SizePartitioner
from .square_partitioner import SquarePartitioner
@@ -34,6 +35,7 @@
"LinearPartitioner",
"NaturalIdPartitioner",
"Partitioner",
+ "PathologicalPartitioner",
"ShardPartitioner",
"SizePartitioner",
"SquarePartitioner",
diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py
new file mode 100644
index 000000000000..1ee60d283044
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py
@@ -0,0 +1,305 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pathological partitioner class that works with Hugging Face Datasets."""
+
+
+import warnings
+from typing import Any, Dict, List, Literal, Optional
+
+import numpy as np
+
+import datasets
+from flwr_datasets.common.typing import NDArray
+from flwr_datasets.partitioner.partitioner import Partitioner
+
+
+# pylint: disable=too-many-arguments, too-many-instance-attributes
+class PathologicalPartitioner(Partitioner):
+ """Partition dataset such that each partition has a chosen number of classes.
+
+ Implementation based on Federated Learning on Non-IID Data Silos: An Experimental
+ Study https://arxiv.org/pdf/2102.02079.
+
+ The algorithm first determines which classes will be assigned to which
+ partitions. For each partition, `num_classes_per_partition` classes are sampled
+ according to `class_assignment_mode`. Given the classes required by each
+ partition, it is determined into how many parts the samples corresponding to
+ each label should be divided. Such a division is performed for each class.
+
+ Parameters
+ ----------
+ num_partitions : int
+ The total number of partitions that the data will be divided into.
+ partition_by : str
+ Column name of the labels (targets) based on which partitioning works.
+ num_classes_per_partition: int
+ The (exact) number of unique classes that each partition will have.
+ class_assignment_mode: Literal["random", "deterministic", "first-deterministic"]
+ The way in which the classes are assigned to the partitions. The default is "random".
+ The possible values are:
+
+ - "random": Randomly assign classes to the partitions. For each partition choose
+ the `num_classes_per_partition` classes without replacement.
+ - "first-deterministic": Assign the first class for each partition in a
+ deterministic way (class id is the partition_id % num_unique_classes).
+ The rest of the classes are assigned randomly. In case the number of
+ partitions is smaller than the number of unique classes, not all classes will
+ be used in the first iteration, otherwise all the classes will be used (such
+ it will be present in at least one partition).
+ - "deterministic": Assign all the classes to the partitions in a deterministic
+ way. Classes are assigned based on the formula: partion_id has classes
+ identified by the index: (partition_id + i) % num_unique_classes
+ where i in {0, ..., num_classes_per_partition}. So, partition 0 will have
+ classes 0, 1, 2, ..., `num_classes_per_partition`-1, partition 1 will have
+ classes 1, 2, 3, ...,`num_classes_per_partition`, ....
+
+ The list of unique labels is sorted in ascending order. In case of integer
+ labels starting from zero, the class id corresponds to the label itself.
+ `class_assignment_mode="first-deterministic"` was used in the original paper;
+ here we provide the option to use the other modes as well.
+ shuffle: bool
+ Whether to randomize the order of samples. Shuffling is applied after the
+ samples are assigned to partitions.
+ seed: int
+ Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+
+ Examples
+ --------
+ In order to mimic the original behavior of the paper, follow the setup below
+ (with `class_assignment_mode="first-deterministic"`):
+
+ >>> from flwr_datasets.partitioner import PathologicalPartitioner
+ >>> from flwr_datasets import FederatedDataset
+ >>>
+ >>> partitioner = PathologicalPartitioner(
+ >>> num_partitions=10,
+ >>> partition_by="label",
+ >>> num_classes_per_partition=2,
+ >>> class_assignment_mode="first-deterministic"
+ >>> )
+ >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+ >>> partition = fds.load_partition(0)
+ """
+
+ def __init__(
+ self,
+ num_partitions: int,
+ partition_by: str,
+ num_classes_per_partition: int,
+ class_assignment_mode: Literal[
+ "random", "deterministic", "first-deterministic"
+ ] = "random",
+ shuffle: bool = True,
+ seed: Optional[int] = 42,
+ ) -> None:
+ super().__init__()
+ self._num_partitions = num_partitions
+ self._partition_by = partition_by
+ self._num_classes_per_partition = num_classes_per_partition
+ self._class_assignment_mode = class_assignment_mode
+ self._shuffle = shuffle
+ self._seed = seed
+ self._rng = np.random.default_rng(seed=self._seed)
+
+ # Utility attributes
+ self._partition_id_to_indices: Dict[int, List[int]] = {}
+ self._partition_id_to_unique_labels: Dict[int, List[Any]] = {
+ pid: [] for pid in range(self._num_partitions)
+ }
+ self._unique_labels: List[Any] = []
+ # Count in how many partitions the label is used
+ self._unique_label_to_times_used_counter: Dict[Any, int] = {}
+ self._partition_id_to_indices_determined = False
+
+ def load_partition(self, partition_id: int) -> datasets.Dataset:
+ """Load a partition based on the partition index.
+
+ Parameters
+ ----------
+ partition_id : int
+ The index that corresponds to the requested partition.
+
+ Returns
+ -------
+ dataset_partition : Dataset
+ Single partition of a dataset.
+ """
+ # The partitioning is done lazily - only when the first partition is
+ # requested. Only the first call creates the index assignments for all
+ # partitions.
+ self._check_num_partitions_correctness_if_needed()
+ self._determine_partition_id_to_indices_if_needed()
+ return self.dataset.select(self._partition_id_to_indices[partition_id])
+
+ @property
+ def num_partitions(self) -> int:
+ """Total number of partitions."""
+ self._check_num_partitions_correctness_if_needed()
+ self._determine_partition_id_to_indices_if_needed()
+ return self._num_partitions
+
+ def _determine_partition_id_to_indices_if_needed(self) -> None:
+ """Create an assignment of indices to the partition indices."""
+ if self._partition_id_to_indices_determined:
+ return
+ self._determine_partition_id_to_unique_labels()
+ assert self._unique_labels is not None
+ self._count_partitions_having_each_unique_label()
+
+ labels = np.asarray(self.dataset[self._partition_by])
+ self._check_correctness_of_unique_label_to_times_used_counter(labels)
+ for partition_id in range(self._num_partitions):
+ self._partition_id_to_indices[partition_id] = []
+
+ unused_labels = []
+ for unique_label in self._unique_labels:
+ if self._unique_label_to_times_used_counter[unique_label] == 0:
+ unused_labels.append(unique_label)
+ continue
+ # Get the indices in the original dataset where y == unique_label
+ unique_label_to_indices = np.where(labels == unique_label)[0]
+
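+ # Divide the label's indices into as many (near-equal) chunks as there
+ # are partitions that use this label; each such partition takes one chunk.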
+ split_unique_labels_to_indices = np.array_split(
+ unique_label_to_indices,
+ self._unique_label_to_times_used_counter[unique_label],
+ )
+
+ split_index = 0
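+ # Iterate over the partitions in order; each partition that uses this
+ # label consumes the next chunk of indices.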
+ for partition_id in range(self._num_partitions):
+ if unique_label in self._partition_id_to_unique_labels[partition_id]:
+ self._partition_id_to_indices[partition_id].extend(
+ split_unique_labels_to_indices[split_index]
+ )
+ split_index += 1
+
+ if len(unused_labels) >= 1:
+ warnings.warn(
+ f"Classes: {unused_labels} will NOT be used due to the chosen "
+ f"configuration. If it is undesired behavior consider setting"
+ f" 'first_class_deterministic_assignment=True' which in case when"
+ f" the number of classes is smaller than the number of partitions will "
+ f"utilize all the classes for the created partitions.",
+ stacklevel=1,
+ )
+ if self._shuffle:
+ for indices in self._partition_id_to_indices.values():
+ # In place shuffling
+ self._rng.shuffle(indices)
+
+ self._partition_id_to_indices_determined = True
+
+ def _check_num_partitions_correctness_if_needed(self) -> None:
+ """Test num_partitions when the dataset is given (in load_partition)."""
+ if not self._partition_id_to_indices_determined:
+ if self._num_partitions > self.dataset.num_rows:
+ raise ValueError(
+ "The number of partitions needs to be smaller than the number of "
+ "samples in the dataset."
+ )
+
+ def _determine_partition_id_to_unique_labels(self) -> None:
+ """Determine the assignment of unique labels to the partitions."""
+ self._unique_labels = sorted(self.dataset.unique(self._partition_by))
+ num_unique_classes = len(self._unique_labels)
+
+ if self._num_classes_per_partition > num_unique_classes:
+ raise ValueError(
+ f"The specified `num_classes_per_partition`"
+ f"={self._num_classes_per_partition} is greater than the number "
+ f"of unique classes in the given dataset={num_unique_classes}. "
+ f"Reduce the `num_classes_per_partition` or make use different dataset "
+ f"to apply this partitioning."
+ )
+ if self._class_assignment_mode == "first-deterministic":
+ for partition_id in range(self._num_partitions):
+ label = partition_id % num_unique_classes
+ self._partition_id_to_unique_labels[partition_id].append(label)
+
+ while (
+ len(self._partition_id_to_unique_labels[partition_id])
+ < self._num_classes_per_partition
+ ):
+ label = self._rng.choice(self._unique_labels, size=1)[0]
+ if label not in self._partition_id_to_unique_labels[partition_id]:
+ self._partition_id_to_unique_labels[partition_id].append(label)
+ elif self._class_assignment_mode == "deterministic":
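+ # Cyclic assignment: partition p receives the labels at indices
+ # p, p+1, ..., p + num_classes_per_partition - 1 (mod num_unique_classes).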
+ for partition_id in range(self._num_partitions):
+ labels = []
+ for i in range(self._num_classes_per_partition):
+ label = self._unique_labels[
+ (partition_id + i) % len(self._unique_labels)
+ ]
+ labels.append(label)
+ self._partition_id_to_unique_labels[partition_id] = labels
+ elif self._class_assignment_mode == "random":
+ for partition_id in range(self._num_partitions):
+ labels = self._rng.choice(
+ self._unique_labels,
+ size=self._num_classes_per_partition,
+ replace=False,
+ ).tolist()
+ self._partition_id_to_unique_labels[partition_id] = labels
+ else:
+ raise ValueError(
+ f"The supported class_assignment_mode are: 'random', 'deterministic', "
+ f"'first-deterministic'. You provided: {self._class_assignment_mode}."
+ )
+
+ def _count_partitions_having_each_unique_label(self) -> None:
+ """Count the number of partitions that have each unique label.
+
+ This computation is based on the assignment of labels to partition_ids in
+ the `_determine_partition_id_to_unique_labels` method.
+ Given:
+ * partition 0 has only labels 0 and 1 (not necessarily just two samples; it
+ can have many samples, but each of them has label 0 or 1),
+ * partition 1 has only labels 1 and 2 (same note as above),
+ * and there are only two partitions, then the following will be computed:
+ {
+ 0: 1,
+ 1: 2,
+ 2: 1
+ }
+ """
+ for unique_label in self._unique_labels:
+ self._unique_label_to_times_used_counter[unique_label] = 0
+ for unique_labels in self._partition_id_to_unique_labels.values():
+ for unique_label in unique_labels:
+ self._unique_label_to_times_used_counter[unique_label] += 1
+
+ def _check_correctness_of_unique_label_to_times_used_counter(
+ self, labels: NDArray
+ ) -> None:
+ """Check if partitioning is possible given the presence requirements.
+
+ The number of times a label is used must be smaller than or equal to the
+ number of times that the label is present in the dataset.
+ """
+ for unique_label in self._unique_labels:
+ num_unique = np.sum(labels == unique_label)
+ if self._unique_label_to_times_used_counter[unique_label] > num_unique:
+ raise ValueError(
+ f"Label: {unique_label} is needed to be assigned to more "
+ f"partitions "
+ f"({self._unique_label_to_times_used_counter[unique_label]})"
+ f" than there are samples (corresponding to this label) in the "
+ f"dataset ({num_unique}). Please decrease the `num_partitions`, "
+ f"`num_classes_per_partition` to avoid this situation, "
+ f"or try `class_assigment_mode='deterministic'` to create a more "
+ f"even distribution of classes along the partitions. "
+ f"Alternatively use a different dataset if you can not adjust"
+ f" the any of these parameters."
+ )
diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py
new file mode 100644
index 000000000000..151b7e14659c
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py
@@ -0,0 +1,262 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for PathologicalPartitioner."""
+
+
+import unittest
+from typing import Dict
+
+import numpy as np
+from parameterized import parameterized
+
+import datasets
+from datasets import Dataset
+from flwr_datasets.partitioner.pathological_partitioner import PathologicalPartitioner
+
+
+def _dummy_dataset_setup(
+ num_samples: int, partition_by: str, num_unique_classes: int
+) -> Dataset:
+ """Create a dummy dataset for testing."""
+ data = {
+ partition_by: np.tile(
+ np.arange(num_unique_classes), num_samples // num_unique_classes + 1
+ )[:num_samples],
+ "features": np.random.randn(num_samples),
+ }
+ return Dataset.from_dict(data)
+
+
+class TestPathologicalPartitioner(unittest.TestCase):
+ """Unit tests for PathologicalPartitioner."""
+
+ @parameterized.expand( # type: ignore
+ [
+ # num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ (3, 1, 60, 3), # Single class per partition scenario
+ (5, 2, 100, 5),
+ (5, 2, 100, 10),
+ (4, 3, 120, 6),
+ ]
+ )
+ def test_correct_num_classes_when_partitioned(
+ self,
+ num_partitions: int,
+ num_classes_per_partition: int,
+ num_samples: int,
+ num_unique_classes: int,
+ ) -> None:
+ """Test correct number of unique classes."""
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes)
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ )
+ partitioner.dataset = dataset
+ partitions: Dict[int, Dataset] = {
+ pid: partitioner.load_partition(pid) for pid in range(num_partitions)
+ }
+ unique_classes_per_partition = {
+ pid: np.unique(partition["labels"]) for pid, partition in partitions.items()
+ }
+
+ for unique_classes in unique_classes_per_partition.values():
+ self.assertEqual(num_classes_per_partition, len(unique_classes))
+
+ def test_first_class_deterministic_assignment(self) -> None:
+ """Test deterministic assignment of first classes to partitions.
+
+ Test if all the classes are used (which has to be the case, given that
+ num_partitions >= the number of unique classes).
+ """
+ dataset = _dummy_dataset_setup(100, "labels", 10)
+ partitioner = PathologicalPartitioner(
+ num_partitions=10,
+ partition_by="labels",
+ num_classes_per_partition=2,
+ class_assignment_mode="first-deterministic",
+ )
+ partitioner.dataset = dataset
+ partitioner.load_partition(0)
+ expected_classes = set(range(10))
+ actual_classes = set()
+ for pid in range(10):
+ partition = partitioner.load_partition(pid)
+ actual_classes.update(np.unique(partition["labels"]))
+ self.assertEqual(expected_classes, actual_classes)
+
+ @parameterized.expand(
+ [ # type: ignore
+ # num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ (4, 2, 80, 8),
+ (10, 2, 100, 10),
+ ]
+ )
+ def test_deterministic_class_assignment(
+ self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ ):
+ """Test deterministic assignment of classes to partitions."""
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes)
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ class_assignment_mode="deterministic",
+ )
+ partitioner.dataset = dataset
+ partitions = {
+ pid: partitioner.load_partition(pid) for pid in range(num_partitions)
+ }
+
+ # Verify each partition has the expected classes, order does not matter
+ for pid, partition in partitions.items():
+ expected_labels = sorted(
+ [
+ (pid + i) % num_unique_classes
+ for i in range(num_classes_per_partition)
+ ]
+ )
+ actual_labels = sorted(np.unique(partition["labels"]))
+ self.assertTrue(
+ np.array_equal(expected_labels, actual_labels),
+ f"Partition {pid} does not have the expected labels: "
+ f"{expected_labels} but instead {actual_labels}.",
+ )
+
+ @parameterized.expand(
+ [ # type: ignore
+ # num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ (10, 3, 20, 3),
+ ]
+ )
+ def test_too_many_partitions_for_a_class(
+ self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ ) -> None:
+ """Test too many partitions for the number of samples in a class."""
+ dataset_1 = _dummy_dataset_setup(
+ num_samples // 2, "labels", num_unique_classes - 1
+ )
+ # Create a skewed part of the dataset for the last label
+ data = {
+ "labels": np.array([num_unique_classes - 1] * (num_samples // 2)),
+ "features": np.random.randn(num_samples // 2),
+ }
+ dataset_2 = Dataset.from_dict(data)
+ dataset = datasets.concatenate_datasets([dataset_1, dataset_2])
+
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ class_assignment_mode="random",
+ )
+ partitioner.dataset = dataset
+
+ with self.assertRaises(ValueError) as context:
+ _ = partitioner.load_partition(0)
+ self.assertEqual(
+ str(context.exception),
+ "Label: 0 is needed to be assigned to more partitions (10) than there are "
+ "samples (corresponding to this label) in the dataset (5). "
+ "Please decrease the `num_partitions`, `num_classes_per_partition` to "
+ "avoid this situation, or try `class_assigment_mode='deterministic'` to "
+ "create a more even distribution of classes along the partitions. "
+ "Alternatively use a different dataset if you can not adjust the any of "
+ "these parameters.",
+ )
+
+ @parameterized.expand( # type: ignore
+ [
+ # num_partitions, num_classes_per_partition, num_samples, num_unique_classes
+ (10, 11, 100, 10), # 11 > 10
+ (5, 11, 100, 10), # 11 > 10
+ (10, 20, 100, 5), # 20 > 5
+ ]
+ )
+ def test_more_classes_per_partition_than_num_unique_classes_in_dataset_raises(
+ self,
+ num_partitions: int,
+ num_classes_per_partition: int,
+ num_samples: int,
+ num_unique_classes: int,
+ ) -> None:
+ """Test more num_classes_per_partition > num_unique_classes in the dataset."""
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes)
+ with self.assertRaises(ValueError) as context:
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ )
+ partitioner.dataset = dataset
+ partitioner.load_partition(0)
+ self.assertEqual(
+ str(context.exception),
+ "The specified "
+ f"`num_classes_per_partition`={num_classes_per_partition} is "
+ f"greater than the number of unique classes in the given "
+ f"dataset={len(dataset.unique('labels'))}. Reduce the "
+ f"`num_classes_per_partition` or make use different dataset "
+ f"to apply this partitioning.",
+ )
+
+ @parameterized.expand( # type: ignore
+ [
+ # num_classes_per_partition should be irrelevant since the exception should
+ # be raised at the very beginning
+ # num_partitions, num_classes_per_partition, num_samples
+ (10, 2, 5),
+ (10, 10, 5),
+ (100, 10, 99),
+ ]
+ )
+ def test_more_partitions_than_samples_raises(
+ self, num_partitions: int, num_classes_per_partition: int, num_samples: int
+ ) -> None:
+ """Test if generation of more partitions that there are samples raises."""
+ # The number of unique classes in the dataset should be irrelevant since the
+ # exception should be raised at the very beginning
+ dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes=5)
+ with self.assertRaises(ValueError) as context:
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="labels",
+ num_classes_per_partition=num_classes_per_partition,
+ )
+ partitioner.dataset = dataset
+ partitioner.load_partition(0)
+ self.assertEqual(
+ str(context.exception),
+ "The number of partitions needs to be smaller than the number of "
+ "samples in the dataset.",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py
index e82f17088bd9..c7b0d59b8ea5 100644
--- a/e2e/bare-client-auth/client.py
+++ b/e2e/bare-client-auth/client.py
@@ -1,13 +1,14 @@
import numpy as np
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
model_params = np.array([1])
objective = 5
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params
@@ -23,10 +24,10 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py
index 8f5c1412fd01..4a682af3aec3 100644
--- a/e2e/bare-https/client.py
+++ b/e2e/bare-https/client.py
@@ -2,14 +2,15 @@
import numpy as np
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
model_params = np.array([1])
objective = 5
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params
@@ -25,17 +26,17 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
+ start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
root_certificates=Path("certificates/ca.crt").read_bytes(),
diff --git a/e2e/bare/client.py b/e2e/bare/client.py
index 402d775ac3a9..943e60d5db9f 100644
--- a/e2e/bare/client.py
+++ b/e2e/bare/client.py
@@ -2,8 +2,8 @@
import numpy as np
-import flwr as fl
-from flwr.common import ConfigsRecord
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import ConfigsRecord, Context
SUBSET_SIZE = 1000
STATE_VAR = "timestamp"
@@ -14,7 +14,7 @@
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model_params
@@ -51,16 +51,14 @@ def evaluate(self, parameters, config):
)
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
- server_address="127.0.0.1:8080", client=FlowerClient().to_client()
- )
+ start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
diff --git a/e2e/docker/client.py b/e2e/docker/client.py
index 8451b810416b..44313c7c3af6 100644
--- a/e2e/docker/client.py
+++ b/e2e/docker/client.py
@@ -9,6 +9,7 @@
from torchvision.transforms import Compose, Normalize, ToTensor
from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
@@ -122,7 +123,7 @@ def evaluate(self, parameters, config):
return loss, len(testloader.dataset), {"accuracy": accuracy}
-def client_fn(cid: str):
+def client_fn(context: Context):
"""Create and return an instance of Flower `Client`."""
return FlowerClient().to_client()
diff --git a/e2e/framework-fastai/client.py b/e2e/framework-fastai/client.py
index 1d98a1134941..161b27b5a548 100644
--- a/e2e/framework-fastai/client.py
+++ b/e2e/framework-fastai/client.py
@@ -5,7 +5,8 @@
import torch
from fastai.vision.all import *
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
warnings.filterwarnings("ignore", category=UserWarning)
@@ -29,7 +30,7 @@
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in learn.model.state_dict().items()]
@@ -49,18 +50,18 @@ def evaluate(self, parameters, config):
return loss, len(dls.valid), {"accuracy": 1 - error_rate}
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
+ start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
diff --git a/e2e/framework-jax/client.py b/e2e/framework-jax/client.py
index 347a005d923a..c9ff67b3e38e 100644
--- a/e2e/framework-jax/client.py
+++ b/e2e/framework-jax/client.py
@@ -6,7 +6,8 @@
import jax_training
import numpy as np
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
# Load data and determine model shape
train_x, train_y, test_x, test_y = jax_training.load_data()
@@ -14,7 +15,7 @@
model_shape = train_x.shape[1:]
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def __init__(self):
self.params = jax_training.load_model(model_shape)
@@ -48,16 +49,14 @@ def evaluate(
return float(loss), num_examples, {"loss": float(loss)}
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
- server_address="127.0.0.1:8080", client=FlowerClient().to_client()
- )
+ start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
diff --git a/e2e/framework-opacus/client.py b/e2e/framework-opacus/client.py
index c9ebe319063a..167fa4584e37 100644
--- a/e2e/framework-opacus/client.py
+++ b/e2e/framework-opacus/client.py
@@ -9,7 +9,8 @@
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
# Define parameters.
PARAMS = {
@@ -95,7 +96,7 @@ def load_data():
# Define Flower client.
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def __init__(self, model) -> None:
super().__init__()
# Create a privacy engine which will add DP and keep track of the privacy budget.
@@ -139,16 +140,16 @@ def evaluate(self, parameters, config):
return float(loss), len(testloader), {"accuracy": float(accuracy)}
-def client_fn(cid):
+def client_fn(context: Context):
model = Net()
return FlowerClient(model).to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
- fl.client.start_client(
+ start_client(
server_address="127.0.0.1:8080", client=FlowerClient(model).to_client()
)
diff --git a/e2e/framework-pandas/client.py b/e2e/framework-pandas/client.py
index 19e15f5a3b11..0c3300e1dd3f 100644
--- a/e2e/framework-pandas/client.py
+++ b/e2e/framework-pandas/client.py
@@ -3,7 +3,8 @@
import numpy as np
import pandas as pd
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
df = pd.read_csv("./data/client.csv")
@@ -16,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray:
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
@@ -32,17 +33,17 @@ def fit(
)
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
+ start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
diff --git a/e2e/framework-pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py
index fdd55b3dc344..bf291a1ca2c5 100644
--- a/e2e/framework-pytorch-lightning/client.py
+++ b/e2e/framework-pytorch-lightning/client.py
@@ -4,10 +4,11 @@
import pytorch_lightning as pl
import torch
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
self.train_loader = train_loader
@@ -51,7 +52,7 @@ def _set_parameters(model, parameters):
model.load_state_dict(state_dict, strict=True)
-def client_fn(cid):
+def client_fn(context: Context):
model = mnist.LitAutoEncoder()
train_loader, val_loader, test_loader = mnist.load_data()
@@ -59,7 +60,7 @@ def client_fn(cid):
return FlowerClient(model, train_loader, val_loader, test_loader).to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
@@ -71,7 +72,7 @@ def main() -> None:
# Flower client
client = FlowerClient(model, train_loader, val_loader, test_loader).to_client()
- fl.client.start_client(server_address="127.0.0.1:8080", client=client)
+ start_client(server_address="127.0.0.1:8080", client=client)
if __name__ == "__main__":
diff --git a/e2e/framework-pytorch/client.py b/e2e/framework-pytorch/client.py
index dbfbfed1ffa7..ab4bc7b5c5b9 100644
--- a/e2e/framework-pytorch/client.py
+++ b/e2e/framework-pytorch/client.py
@@ -10,8 +10,8 @@
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm
-import flwr as fl
-from flwr.common import ConfigsRecord
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import ConfigsRecord, Context
# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
@@ -89,7 +89,7 @@ def load_data():
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
@@ -136,18 +136,18 @@ def set_parameters(model, parameters):
return
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
+ start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py
index b0691e75a79d..24c6617c1289 100644
--- a/e2e/framework-scikit-learn/client.py
+++ b/e2e/framework-scikit-learn/client.py
@@ -5,7 +5,8 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
# Load MNIST dataset from https://www.openml.org/d/554
(X_train, y_train), (X_test, y_test) = utils.load_mnist()
@@ -26,7 +27,7 @@
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config): # type: ignore
return utils.get_model_parameters(model)
@@ -45,16 +46,14 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"accuracy": accuracy}
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
- server_address="0.0.0.0:8080", client=FlowerClient().to_client()
- )
+ start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client())
diff --git a/e2e/framework-tensorflow/client.py b/e2e/framework-tensorflow/client.py
index 779be0c3746d..351f495a3acb 100644
--- a/e2e/framework-tensorflow/client.py
+++ b/e2e/framework-tensorflow/client.py
@@ -2,7 +2,8 @@
import tensorflow as tf
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
SUBSET_SIZE = 1000
@@ -18,7 +19,7 @@
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model.get_weights()
@@ -33,16 +34,14 @@ def evaluate(self, parameters, config):
return loss, len(x_test), {"accuracy": accuracy}
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
- server_address="127.0.0.1:8080", client=FlowerClient().to_client()
- )
+ start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py
index 505340e013a5..0403416cc3b7 100644
--- a/e2e/strategies/client.py
+++ b/e2e/strategies/client.py
@@ -2,7 +2,8 @@
import tensorflow as tf
-import flwr as fl
+from flwr.client import ClientApp, NumPyClient, start_client
+from flwr.common import Context
SUBSET_SIZE = 1000
@@ -33,7 +34,7 @@ def get_model():
# Define Flower client
-class FlowerClient(fl.client.NumPyClient):
+class FlowerClient(NumPyClient):
def get_parameters(self, config):
return model.get_weights()
@@ -48,17 +49,15 @@ def evaluate(self, parameters, config):
return loss, len(x_test), {"accuracy": accuracy}
-def client_fn(cid):
+def client_fn(context: Context):
return FlowerClient().to_client()
-app = fl.client.ClientApp(
+app = ClientApp(
client_fn=client_fn,
)
if __name__ == "__main__":
# Start Flower client
- fl.client.start_client(
- server_address="127.0.0.1:8080", client=FlowerClient().to_client()
- )
+ start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py
index abf9cdb5a5c7..c567f33b236b 100644
--- a/e2e/strategies/test.py
+++ b/e2e/strategies/test.py
@@ -3,8 +3,8 @@
import tensorflow as tf
from client import SUBSET_SIZE, FlowerClient, get_model
-import flwr as fl
-from flwr.common import ndarrays_to_parameters
+from flwr.common import Context, ndarrays_to_parameters
+from flwr.server import ServerConfig
from flwr.server.strategy import (
FaultTolerantFedAvg,
FedAdagrad,
@@ -15,6 +15,7 @@
FedYogi,
QFedAvg,
)
+from flwr.simulation import start_simulation
STRATEGY_LIST = [
FedMedian,
@@ -42,8 +43,7 @@ def get_strat(name):
init_model = get_model()
-def client_fn(cid):
- _ = cid
+def client_fn(context: Context):
return FlowerClient()
@@ -71,10 +71,10 @@ def evaluate(server_round, parameters, config):
if start_idx >= OPT_IDX:
strat_args["tau"] = 0.01
-hist = fl.simulation.start_simulation(
+hist = start_simulation(
client_fn=client_fn,
num_clients=2,
- config=fl.server.ServerConfig(num_rounds=3),
+ config=ServerConfig(num_rounds=3),
strategy=strategy(**strat_args),
)
diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py
index 851083d4abb7..348ef8910dd3 100644
--- a/src/py/flwr/client/app.py
+++ b/src/py/flwr/client/app.py
@@ -18,7 +18,7 @@
import sys
import time
from dataclasses import dataclass
-from logging import DEBUG, ERROR, INFO, WARN
+from logging import ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
@@ -28,7 +28,7 @@
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.typing import ClientFnExt
-from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
+from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
from flwr.common.address import parse_address
from flwr.common.constant import (
MISSING_EXTRA_REST,
@@ -138,8 +138,8 @@ class `flwr.client.Client` (default: None)
Starting an SSL-enabled gRPC client using system certificates:
- >>> def client_fn(node_id: int, partition_id: Optional[int]):
- >>> return FlowerClient()
+ >>> def client_fn(context: Context):
+ >>> return FlowerClient().to_client()
>>>
>>> start_client(
>>> server_address=localhost:8080,
@@ -160,6 +160,7 @@ class `flwr.client.Client` (default: None)
event(EventType.START_CLIENT_ENTER)
_start_client_internal(
server_address=server_address,
+ node_config={},
load_client_app_fn=None,
client_fn=client_fn,
client=client,
@@ -181,6 +182,7 @@ class `flwr.client.Client` (default: None)
def _start_client_internal(
*,
server_address: str,
+ node_config: Dict[str, str],
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
client_fn: Optional[ClientFnExt] = None,
client: Optional[Client] = None,
@@ -193,7 +195,6 @@ def _start_client_internal(
] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
- partition_id: Optional[int] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
@@ -204,6 +205,8 @@ def _start_client_internal(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
+ node_config: Dict[str, str]
+ The configuration of the node.
load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
A function that can be used to load a `ClientApp` instance.
client_fn : Optional[ClientFnExt]
@@ -238,9 +241,6 @@ class `flwr.client.Client` (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
- partition_id: Optional[int] (default: None)
- The data partition index associated with this node. Better suited for
- prototyping purposes.
flwr_dir: Optional[Path] (default: None)
The fully resolved path containing installed Flower Apps.
"""
@@ -253,8 +253,7 @@ class `flwr.client.Client` (default: None)
if client_fn is None:
# Wrap `Client` instance in `client_fn`
def single_client_factory(
- node_id: int, # pylint: disable=unused-argument
- partition_id: Optional[int], # pylint: disable=unused-argument
+ context: Context, # pylint: disable=unused-argument
) -> Client:
if client is None: # Added this to keep mypy happy
raise ValueError(
@@ -295,7 +294,7 @@ def _on_backoff(retry_state: RetryState) -> None:
log(WARN, "Connection attempt failed, retrying...")
else:
log(
- DEBUG,
+ WARN,
"Connection attempt failed, retrying in %.2f seconds",
retry_state.actual_wait,
)
@@ -319,7 +318,9 @@ def _on_backoff(retry_state: RetryState) -> None:
on_backoff=_on_backoff,
)
- node_state = NodeState(partition_id=partition_id)
+ # NodeState gets initialized when the first connection is established
+ node_state: Optional[NodeState] = None
+
runs: Dict[int, Run] = {}
while not app_state_tracker.interrupt:
@@ -334,9 +335,31 @@ def _on_backoff(retry_state: RetryState) -> None:
) as conn:
receive, send, create_node, delete_node, get_run = conn
- # Register node
- if create_node is not None:
- create_node() # pylint: disable=not-callable
+ # Register node when connecting for the first time
+ if node_state is None:
+ if create_node is None:
+ if transport not in ["grpc-bidi", None]:
+ raise NotImplementedError(
+ "All transports except `grpc-bidi` require "
+ "an implementation for `create_node()`.'"
+ )
+ # gRPC-bidi doesn't have the concept of node_id,
+ # so we set it to -1
+ node_state = NodeState(
+ node_id=-1,
+ node_config={},
+ )
+ else:
+ # Call create_node fn to register node
+ node_id: Optional[int] = ( # pylint: disable=assignment-from-none
+ create_node()
+ ) # pylint: disable=not-callable
+ if node_id is None:
+ raise ValueError("Node registration failed")
+ node_state = NodeState(
+ node_id=node_id,
+ node_config=node_config,
+ )
app_state_tracker.register_signal_handler()
while not app_state_tracker.interrupt:
@@ -580,7 +603,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
- Optional[Callable[[], None]],
+ Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py
index 663d83a8b19e..2a913b3a248d 100644
--- a/src/py/flwr/client/client_app.py
+++ b/src/py/flwr/client/client_app.py
@@ -30,21 +30,41 @@
from .typing import ClientAppCallable
+def _alert_erroneous_client_fn() -> None:
+ raise ValueError(
+ "A `ClientApp` cannot make use of a `client_fn` that does "
+ "not have a signature in the form: `def client_fn(context: "
+ "Context)`. You can import the `Context` like this: "
+ "`from flwr.common import Context`"
+ )
+
+
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
client_fn_args = inspect.signature(client_fn).parameters
+ if len(client_fn_args) != 1:
+ _alert_erroneous_client_fn()
+
+ first_arg = list(client_fn_args.keys())[0]
+ first_arg_type = client_fn_args[first_arg].annotation
- if not all(key in client_fn_args for key in ["node_id", "partition_id"]):
+ if first_arg_type is str or first_arg == "cid":
+ # Warn that the previous signature for `client_fn` seems to be used
warn_deprecated_feature(
- "`client_fn` now expects a signature `def client_fn(node_id: int, "
- "partition_id: Optional[int])`.\nYou provided `client_fn` with signature: "
- f"{dict(client_fn_args.items())}"
+ "`client_fn` now expects a signature `def client_fn(context: Context)`."
+ "The provided `client_fn` has signature: "
+ f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
+ " `from flwr.common import Context`"
)
# Wrap depcreated client_fn inside a function with the expected signature
def adaptor_fn(
- node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
- ) -> Client:
- return client_fn(str(partition_id)) # type: ignore
+ context: Context,
+ ) -> Client: # pylint: disable=unused-argument
+ # If partition-id is defined, pass it. Else, pass the node_id, which
+ # should always be defined during Context init.
+ cid = context.node_config.get("partition-id", context.node_id)
+ return client_fn(str(cid)) # type: ignore
return adaptor_fn
@@ -71,7 +91,7 @@ class ClientApp:
>>> class FlowerClient(NumPyClient):
>>> # ...
>>>
- >>> def client_fn(node_id: int, partition_id: Optional[int]):
+ >>> def client_fn(context: Context):
>>> return FlowerClient().to_client()
>>>
>>> app = ClientApp(client_fn)
diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py
index 971b630e470b..80a5cf0b4656 100644
--- a/src/py/flwr/client/grpc_adapter_client/connection.py
+++ b/src/py/flwr/client/grpc_adapter_client/connection.py
@@ -44,7 +44,7 @@ def grpc_adapter( # pylint: disable=R0913
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
- Optional[Callable[[], None]],
+ Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py
index 3e9f261c1ecf..a6417106d51b 100644
--- a/src/py/flwr/client/grpc_client/connection.py
+++ b/src/py/flwr/client/grpc_client/connection.py
@@ -72,7 +72,7 @@ def grpc_connection( # pylint: disable=R0913, R0915
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
- Optional[Callable[[], None]],
+ Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py
index 8062ce28fcc7..e573df6854bc 100644
--- a/src/py/flwr/client/grpc_rere_client/connection.py
+++ b/src/py/flwr/client/grpc_rere_client/connection.py
@@ -79,7 +79,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
- Optional[Callable[[], None]],
+ Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
@@ -176,7 +176,7 @@ def ping() -> None:
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)
- def create_node() -> None:
+ def create_node() -> Optional[int]:
"""Set create_node."""
# Call FleetAPI
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
@@ -189,6 +189,7 @@ def create_node() -> None:
nonlocal node, ping_thread
node = cast(Node, create_node_response.node)
ping_thread = start_ping_loop(ping, ping_stop_event)
+ return node.node_id
def delete_node() -> None:
"""Set delete_node."""
diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py
index e9a853a92101..1ab84eb01468 100644
--- a/src/py/flwr/client/message_handler/message_handler.py
+++ b/src/py/flwr/client/message_handler/message_handler.py
@@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype(
client_fn: ClientFnExt, message: Message, context: Context
) -> Message:
"""Handle legacy message in the inner most mod."""
- client = client_fn(message.metadata.dst_node_id, context.partition_id)
+ client = client_fn(context)
# Check if NumPyClient is returend
if isinstance(client, NumPyClient):
diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py
index 9ce4c9620c43..557d61ffb32a 100644
--- a/src/py/flwr/client/message_handler/message_handler_test.py
+++ b/src/py/flwr/client/message_handler/message_handler_test.py
@@ -19,7 +19,7 @@
import unittest
import uuid
from copy import copy
-from typing import List, Optional
+from typing import List
from flwr.client import Client
from flwr.client.typing import ClientFnExt
@@ -114,9 +114,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
def _get_client_fn(client: Client) -> ClientFnExt:
- def client_fn(
- node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
- ) -> Client:
+ def client_fn(context: Context) -> Client: # pylint: disable=unused-argument
return client
return client_fn
@@ -145,7 +143,7 @@ def test_client_without_get_properties() -> None:
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
- context=Context(state=RecordSet(), run_config={}),
+ context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}),
)
# Assert
@@ -209,7 +207,7 @@ def test_client_with_get_properties() -> None:
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
- context=Context(state=RecordSet(), run_config={}),
+ context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}),
)
# Assert
diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
index 5e4c4411e1f7..2832576fb4fc 100644
--- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
+++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
@@ -73,7 +73,12 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord:
def _make_ctxt() -> Context:
cfg = ConfigsRecord(SecAggPlusState().to_dict())
- return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}), run_config={})
+ return Context(
+ node_id=123,
+ node_config={},
+ state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}),
+ run_config={},
+ )
def _make_set_state_fn(
diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py
index 7a1dd8988399..a5bbd0a0bb4d 100644
--- a/src/py/flwr/client/mod/utils_test.py
+++ b/src/py/flwr/client/mod/utils_test.py
@@ -104,7 +104,7 @@ def test_multiple_mods(self) -> None:
state = RecordSet()
state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0})
- context = Context(state=state, run_config={})
+ context = Context(node_id=0, node_config={}, state=state, run_config={})
message = _get_dummy_flower_message()
# Execute
@@ -129,7 +129,7 @@ def test_filter(self) -> None:
# Prepare
footprint: List[str] = []
mock_app = make_mock_app("app", footprint)
- context = Context(state=RecordSet(), run_config={})
+ context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})
message = _get_dummy_flower_message()
def filter_mod(
diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py
index 2b090eba9720..393ca4564a35 100644
--- a/src/py/flwr/client/node_state.py
+++ b/src/py/flwr/client/node_state.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config
@@ -35,10 +35,14 @@ class RunInfo:
class NodeState:
"""State of a node where client nodes execute runs."""
- def __init__(self, partition_id: Optional[int]) -> None:
- self._meta: Dict[str, Any] = {} # holds metadata about the node
+ def __init__(
+ self,
+ node_id: int,
+ node_config: Dict[str, str],
+ ) -> None:
+ self.node_id = node_id
+ self.node_config = node_config
self.run_infos: Dict[int, RunInfo] = {}
- self._partition_id = partition_id
def register_context(
self,
@@ -52,9 +56,10 @@ def register_context(
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
+ node_id=self.node_id,
+ node_config=self.node_config,
state=RecordSet(),
run_config=initial_run_config.copy(),
- partition_id=self._partition_id,
),
)
diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py
index effd64a3ae7a..26ac4fea6855 100644
--- a/src/py/flwr/client/node_state_tests.py
+++ b/src/py/flwr/client/node_state_tests.py
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
# NodeState
- node_state = NodeState(partition_id=None)
+ node_state = NodeState(node_id=0, node_config={})
for task in tasks:
run_id = task.run_id
diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py
index 0efa5731ae51..3e81969d898c 100644
--- a/src/py/flwr/client/rest_client/connection.py
+++ b/src/py/flwr/client/rest_client/connection.py
@@ -90,7 +90,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
- Optional[Callable[[], None]],
+ Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
@@ -237,19 +237,20 @@ def ping() -> None:
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)
- def create_node() -> None:
+ def create_node() -> Optional[int]:
"""Set create_node."""
req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
# Send the request
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
if res is None:
- return
+ return None
# Remember the node and the ping-loop thread
nonlocal node, ping_thread
node = res.node
ping_thread = start_ping_loop(ping, ping_stop_event)
+ return node.node_id
def delete_node() -> None:
"""Set delete_node."""
diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py
index 355a2a13a0e5..d61b986bc7af 100644
--- a/src/py/flwr/client/supernode/app.py
+++ b/src/py/flwr/client/supernode/app.py
@@ -29,7 +29,12 @@
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.common import EventType, event
-from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
+from flwr.common.config import (
+ get_flwr_dir,
+ get_project_config,
+ get_project_dir,
+ parse_config_args,
+)
from flwr.common.constant import (
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
@@ -67,7 +72,7 @@ def run_supernode() -> None:
authentication_keys=authentication_keys,
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
- partition_id=args.partition_id,
+ node_config=parse_config_args(args.node_config),
flwr_dir=get_flwr_dir(args.flwr_dir),
)
@@ -93,6 +98,7 @@ def run_client_app() -> None:
_start_client_internal(
server_address=args.superlink,
+ node_config=parse_config_args(args.node_config),
load_client_app_fn=load_fn,
transport=args.transport,
root_certificates=root_certificates,
@@ -389,11 +395,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
help="The SuperNode's public key (as a path str) to enable authentication.",
)
parser.add_argument(
- "--partition-id",
- type=int,
- help="The data partition index associated with this SuperNode. Better suited "
- "for prototyping purposes where a SuperNode might only load a fraction of an "
- "artificially partitioned dataset (e.g. using `flwr-datasets`)",
+ "--node-config",
+ type=str,
+ help="A comma separated list of key/value pairs (separated by `=`) to "
+ "configure the SuperNode. "
+ "E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'",
)
diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py
index bf66a9082c77..9faed4bc7283 100644
--- a/src/py/flwr/client/typing.py
+++ b/src/py/flwr/client/typing.py
@@ -15,7 +15,7 @@
"""Custom types for Flower clients."""
-from typing import Callable, Optional
+from typing import Callable
from flwr.common import Context, Message
@@ -23,7 +23,7 @@
# Compatibility
ClientFn = Callable[[str], Client]
-ClientFnExt = Callable[[int, Optional[int]], Client]
+ClientFnExt = Callable[[Context], Client]
ClientAppCallable = Callable[[Message, Context], Message]
Mod = Callable[[Message, Context, ClientAppCallable], Message]
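Under the new `ClientFnExt` signature, a client factory receives the full `Context` instead of positional `node_id`/`partition_id` arguments. A minimal sketch:

```python
from flwr.client import Client, NumPyClient
from flwr.common import Context


class FlowerClient(NumPyClient):
    def __init__(self, node_id: int) -> None:
        self.node_id = node_id


def client_fn(context: Context) -> Client:
    """Matches ClientFnExt = Callable[[Context], Client]."""
    # Everything previously passed positionally now lives on the Context
    return FlowerClient(node_id=context.node_id).to_client()
```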
diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py
index 9770bdb4af2b..54d74353e4ed 100644
--- a/src/py/flwr/common/config.py
+++ b/src/py/flwr/common/config.py
@@ -121,16 +121,16 @@ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, st
def parse_config_args(
- config_overrides: Optional[str],
+ config: Optional[str],
separator: str = ",",
) -> Dict[str, str]:
"""Parse separator separated list of key-value pairs separated by '='."""
overrides: Dict[str, str] = {}
- if config_overrides is None:
+ if config is None:
return overrides
- overrides_list = config_overrides.split(separator)
+ overrides_list = config.split(separator)
if (
len(overrides_list) == 1
and "=" not in overrides_list
diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py
index f14959589458..72256a62add7 100644
--- a/src/py/flwr/common/constant.py
+++ b/src/py/flwr/common/constant.py
@@ -57,6 +57,9 @@
FAB_CONFIG_FILE = "pyproject.toml"
FLWR_HOME = "FLWR_HOME"
+# Constant entries in Node config for Simulation
+PARTITION_ID_KEY = "partition-id"
+NUM_PARTITIONS_KEY = "num-partitions"
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py
index 8120723ce9e9..4da52ba44481 100644
--- a/src/py/flwr/common/context.py
+++ b/src/py/flwr/common/context.py
@@ -16,7 +16,7 @@
from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict
from .record import RecordSet
@@ -27,6 +27,11 @@ class Context:
Parameters
----------
+ node_id : int
+ The ID that identifies the node.
+ node_config : Dict[str, str]
+ A config (key/value mapping) unique to the node and independent of the
+ `run_config`. This config persists across all runs this node participates in.
state : RecordSet
Holds records added by the entity in a given run and that will stay local.
This means that the data it holds will never leave the system it's running from.
@@ -38,22 +43,21 @@ class Context:
A config (key/value mapping) held by the entity in a given run and that will
stay local. It can be used at any point during the lifecycle of this entity
(e.g. across multiple rounds)
- partition_id : Optional[int] (default: None)
- An index that specifies the data partition that the ClientApp using this Context
- object should make use of. Setting this attribute is better suited for
- simulation or proto typing setups.
"""
+ node_id: int
+ node_config: Dict[str, str]
state: RecordSet
- partition_id: Optional[int]
run_config: Dict[str, str]
- def __init__(
+ def __init__( # pylint: disable=too-many-arguments
self,
+ node_id: int,
+ node_config: Dict[str, str],
state: RecordSet,
run_config: Dict[str, str],
- partition_id: Optional[int] = None,
) -> None:
+ self.node_id = node_id
+ self.node_config = node_config
self.state = state
self.run_config = run_config
- self.partition_id = partition_id
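Constructing the extended `Context` directly now looks as follows (a minimal sketch mirroring the new constructor; the values are illustrative):

```python
from flwr.common import Context, RecordSet

context = Context(
    node_id=7,
    node_config={"partition-id": "3", "num-partitions": "10"},
    state=RecordSet(),
    run_config={},
)
# node_config persists across runs; run_config is scoped to a single run
```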
diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py
index 9e120c824103..ee09d79012dc 100644
--- a/src/py/flwr/server/compat/legacy_context.py
+++ b/src/py/flwr/server/compat/legacy_context.py
@@ -52,4 +52,4 @@ def __init__(
self.strategy = strategy
self.client_manager = client_manager
self.history = History()
- super().__init__(state, run_config={})
+ super().__init__(node_id=0, node_config={}, state=state, run_config={})
diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py
index b4697e99913f..4cc25feb7e0e 100644
--- a/src/py/flwr/server/run_serverapp.py
+++ b/src/py/flwr/server/run_serverapp.py
@@ -78,7 +78,9 @@ def _load() -> ServerApp:
server_app = _load()
# Initialize Context
- context = Context(state=RecordSet(), run_config=server_app_run_config)
+ context = Context(
+ node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
+ )
# Call ServerApp
server_app(driver=driver, context=context)
diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py
index 7de8774d4c81..b0672b3202ed 100644
--- a/src/py/flwr/server/server_app_test.py
+++ b/src/py/flwr/server/server_app_test.py
@@ -29,7 +29,7 @@ def test_server_app_custom_mode() -> None:
# Prepare
app = ServerApp()
driver = MagicMock()
- context = Context(state=RecordSet(), run_config={})
+ context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})
called = {"called": False}
diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
index 0d2f4d193f0b..0ab29a234f88 100644
--- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
+++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
@@ -21,6 +21,7 @@
import ray
from flwr.client.client_app import ClientApp
+from flwr.common.constant import PARTITION_ID_KEY
from flwr.common.context import Context
from flwr.common.logger import log
from flwr.common.message import Message
@@ -168,7 +169,7 @@ def process_message(
Return output message and updated context.
"""
- partition_id = context.partition_id
+ partition_id = context.node_config[PARTITION_ID_KEY]
try:
# Submit a task to the pool
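On the app side, the partition is now read (and cast) from `node_config` instead of a dedicated attribute. A sketch of the typical data-loading pattern, assuming the `FederatedDataset` API from `flwr-datasets`:

```python
from flwr.common import Context
from flwr_datasets import FederatedDataset


def load_data(context: Context):
    # Both keys are stored as strings in node_config
    partition_id = int(context.node_config["partition-id"])
    num_partitions = int(context.node_config["num-partitions"])

    fds = FederatedDataset(
        dataset="cifar10",
        partitioners={"train": num_partitions},
    )
    return fds.load_partition(partition_id, split="train")
```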
diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py
index 287983003f8c..a38cff96ceef 100644
--- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py
+++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py
@@ -23,6 +23,7 @@
from flwr.client import Client, NumPyClient
from flwr.client.client_app import ClientApp, LoadClientAppError
+from flwr.client.node_state import NodeState
from flwr.common import (
DEFAULT_TTL,
Config,
@@ -32,9 +33,9 @@
Message,
MessageTypeLegacy,
Metadata,
- RecordSet,
Scalar,
)
+from flwr.common.constant import PARTITION_ID_KEY
from flwr.common.object_ref import load_app
from flwr.common.recordset_compat import getpropertiesins_to_recordset
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
@@ -53,9 +54,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
return {"result": result}
-def get_dummy_client(
- node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
-) -> Client:
+def get_dummy_client(context: Context) -> Client: # pylint: disable=unused-argument
"""Return a DummyClient converted to Client type."""
return DummyClient().to_client()
@@ -103,12 +102,13 @@ def _create_message_and_context() -> Tuple[Message, Context, float]:
# Construct a Message
mult_factor = 2024
+ run_id = 0
getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
recordset = getpropertiesins_to_recordset(getproperties_ins)
message = Message(
content=recordset,
metadata=Metadata(
- run_id=0,
+ run_id=run_id,
message_id="",
group_id="",
src_node_id=0,
@@ -119,8 +119,10 @@ def _create_message_and_context() -> Tuple[Message, Context, float]:
),
)
- # Construct emtpy Context
- context = Context(state=RecordSet(), run_config={})
+ # Construct NodeState and retrieve context
+ node_state = NodeState(node_id=run_id, node_config={PARTITION_ID_KEY: str(0)})
+ node_state.register_context(run_id=run_id)
+ context = node_state.retrieve_context(run_id=run_id)
# Expected output
expected_output = pi * mult_factor
diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py
index 3c0b36e1ca3c..cd30c40167c5 100644
--- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py
+++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py
@@ -29,7 +29,12 @@
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
from flwr.client.node_state import NodeState
-from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
+from flwr.common.constant import (
+ NUM_PARTITIONS_KEY,
+ PARTITION_ID_KEY,
+ PING_MAX_INTERVAL,
+ ErrorCode,
+)
from flwr.common.logger import log
from flwr.common.message import Error
from flwr.common.object_ref import load_app
@@ -73,7 +78,7 @@ def worker(
task_ins: TaskIns = taskins_queue.get(timeout=1.0)
node_id = task_ins.task.consumer.node_id
- # Register and retrieve runstate
+ # Register and retrieve context
node_states[node_id].register_context(run_id=task_ins.run_id)
context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
@@ -283,8 +288,16 @@ def start_vce(
# Construct mapping of NodeStates
node_states: Dict[int, NodeState] = {}
+ # Number of unique partitions
+ num_partitions = len(set(nodes_mapping.values()))
for node_id, partition_id in nodes_mapping.items():
- node_states[node_id] = NodeState(partition_id=partition_id)
+ node_states[node_id] = NodeState(
+ node_id=node_id,
+ node_config={
+ PARTITION_ID_KEY: str(partition_id),
+ NUM_PARTITIONS_KEY: str(num_partitions),
+ },
+ )
# Load backend config
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py
index 446b0bdeba38..973a9a89e652 100644
--- a/src/py/flwr/simulation/app.py
+++ b/src/py/flwr/simulation/app.py
@@ -111,9 +111,9 @@ def start_simulation(
Parameters
----------
client_fn : ClientFnExt
- A function creating Client instances. The function must have the signature
- `client_fn(node_id: int, partition_id: Optional[int]). It should return
- a single client instance of type Client. Note that the created client
+ A function creating `Client` instances. The function must have the signature
+ `client_fn(context: Context)`. It should return
+ a single client instance of type `Client`. Note that the created client
instances are ephemeral and will often be destroyed after a single method
invocation. Since client instances are not long-lived, they should not attempt
to carry state over method invocations. Any state required by the instance
@@ -327,6 +327,7 @@ def update_resources(f_stop: threading.Event) -> None:
client_fn=client_fn,
node_id=node_id,
partition_id=partition_id,
+ num_partitions=num_clients,
actor_pool=pool,
)
initialized_server.client_manager().register(client=client_proxy)
diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
index 31bc22c84bd5..895272c2fd79 100644
--- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
+++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
@@ -24,7 +24,12 @@
from flwr.client.client_app import ClientApp
from flwr.client.node_state import NodeState
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
-from flwr.common.constant import MessageType, MessageTypeLegacy
+from flwr.common.constant import (
+ NUM_PARTITIONS_KEY,
+ PARTITION_ID_KEY,
+ MessageType,
+ MessageTypeLegacy,
+)
from flwr.common.logger import log
from flwr.common.recordset_compat import (
evaluateins_to_recordset,
@@ -43,11 +48,12 @@
class RayActorClientProxy(ClientProxy):
"""Flower client proxy which delegates work using Ray."""
- def __init__(
+ def __init__( # pylint: disable=too-many-arguments
self,
client_fn: ClientFnExt,
node_id: int,
partition_id: int,
+ num_partitions: int,
actor_pool: VirtualClientEngineActorPool,
):
super().__init__(cid=str(node_id))
@@ -59,7 +65,13 @@ def _load_app() -> ClientApp:
self.app_fn = _load_app
self.actor_pool = actor_pool
- self.proxy_state = NodeState(partition_id=self.partition_id)
+ self.proxy_state = NodeState(
+ node_id=node_id,
+ node_config={
+ PARTITION_ID_KEY: str(partition_id),
+ NUM_PARTITIONS_KEY: str(num_partitions),
+ },
+ )
def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
"""Sumbit a message to the ActorPool."""
@@ -68,18 +80,19 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
# Register state
self.proxy_state.register_context(run_id=run_id)
- # Retrieve state
- state = self.proxy_state.retrieve_context(run_id=run_id)
+ # Retrieve context
+ context = self.proxy_state.retrieve_context(run_id=run_id)
+ partition_id_str = context.node_config[PARTITION_ID_KEY]
try:
self.actor_pool.submit_client_job(
- lambda a, a_fn, mssg, partition_id, state: a.run.remote(
- a_fn, mssg, partition_id, state
+ lambda a, a_fn, mssg, partition_id, context: a.run.remote(
+ a_fn, mssg, partition_id, context
),
- (self.app_fn, message, str(self.partition_id), state),
+ (self.app_fn, message, partition_id_str, context),
)
out_mssg, updated_context = self.actor_pool.get_client_result(
- str(self.partition_id), timeout
+ partition_id_str, timeout
)
# Update state
diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
index 83f6cfe05313..62e0cfd61c99 100644
--- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
+++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
@@ -17,12 +17,13 @@
from math import pi
from random import shuffle
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Dict, List, Tuple, Type
import ray
from flwr.client import Client, NumPyClient
from flwr.client.client_app import ClientApp
+from flwr.client.node_state import NodeState
from flwr.common import (
DEFAULT_TTL,
Config,
@@ -31,15 +32,18 @@
Message,
MessageTypeLegacy,
Metadata,
- RecordSet,
Scalar,
)
+from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY
from flwr.common.recordset_compat import (
getpropertiesins_to_recordset,
recordset_to_getpropertiesres,
)
from flwr.common.recordset_compat_test import _get_valid_getpropertiesins
-from flwr.simulation.app import _create_node_id_to_partition_mapping
+from flwr.simulation.app import (
+ NodeToPartitionMapping,
+ _create_node_id_to_partition_mapping,
+)
from flwr.simulation.ray_transport.ray_actor import (
ClientAppActor,
VirtualClientEngineActor,
@@ -65,16 +69,16 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
return {"result": result}
-def get_dummy_client(
- node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
-) -> Client:
+def get_dummy_client(context: Context) -> Client:
"""Return a DummyClient converted to Client type."""
- return DummyClient(node_id).to_client()
+ return DummyClient(context.node_id).to_client()
def prep(
actor_type: Type[VirtualClientEngineActor] = ClientAppActor,
-) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover
+) -> Tuple[
+ List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping
+]: # pragma: no cover
"""Prepare ClientProxies and pool for tests."""
client_resources = {"num_cpus": 1, "num_gpus": 0.0}
@@ -96,12 +100,13 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:
client_fn=get_dummy_client,
node_id=node_id,
partition_id=partition_id,
+ num_partitions=num_proxies,
actor_pool=pool,
)
for node_id, partition_id in mapping.items()
]
- return proxies, pool
+ return proxies, pool, mapping
def test_cid_consistency_one_at_a_time() -> None:
@@ -109,7 +114,7 @@ def test_cid_consistency_one_at_a_time() -> None:
Submit one job and waits for completion. Then submits the next and so on
"""
- proxies, _ = prep()
+ proxies, _, _ = prep()
getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)
@@ -139,7 +144,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
All jobs are submitted at the same time. Then fetched one at a time. This also tests
NodeState (at each Proxy) and RunState basic functionality.
"""
- proxies, _ = prep()
+ proxies, _, _ = prep()
run_id = 0
getproperties_ins = _get_valid_getpropertiesins()
@@ -186,9 +191,19 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
def test_cid_consistency_without_proxies() -> None:
"""Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy."""
- proxies, pool = prep()
- num_clients = len(proxies)
- node_ids = list(range(num_clients))
+ _, pool, mapping = prep()
+ node_ids = list(mapping.keys())
+
+ # register node states
+ node_states: Dict[int, NodeState] = {}
+ for node_id, partition_id in mapping.items():
+ node_states[node_id] = NodeState(
+ node_id=node_id,
+ node_config={
+ PARTITION_ID_KEY: str(partition_id),
+ NUM_PARTITIONS_KEY: str(len(node_ids)),
+ },
+ )
getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)
@@ -198,11 +213,12 @@ def _load_app() -> ClientApp:
# submit all jobs (collect later)
shuffle(node_ids)
+ run_id = 0
for node_id in node_ids:
message = Message(
content=recordset,
metadata=Metadata(
- run_id=0,
+ run_id=run_id,
message_id="",
group_id=str(0),
src_node_id=0,
@@ -212,20 +228,20 @@ def _load_app() -> ClientApp:
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
)
+ # register and retrieve context
+ node_states[node_id].register_context(run_id=run_id)
+ context = node_states[node_id].retrieve_context(run_id=run_id)
+ partition_id_str = context.node_config[PARTITION_ID_KEY]
pool.submit_client_job(
lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state),
- (
- _load_app,
- message,
- str(node_id),
- Context(state=RecordSet(), run_config={}, partition_id=node_id),
- ),
+ (_load_app, message, partition_id_str, context),
)
# fetch results one at a time
shuffle(node_ids)
for node_id in node_ids:
- message_out, _ = pool.get_client_result(str(node_id), timeout=None)
+ partition_id_str = str(mapping[node_id])
+ message_out, _ = pool.get_client_result(partition_id_str, timeout=None)
res = recordset_to_getpropertiesres(message_out.content)
assert node_id * pi == res.properties["result"]
diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py
index a90422bd635f..35ca8fbb0b87 100644
--- a/src/py/flwr/superexec/app.py
+++ b/src/py/flwr/superexec/app.py
@@ -133,11 +133,11 @@ def _try_obtain_certificates(
return None
# Check if certificates are provided
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
- if not Path.is_file(args.ssl_ca_certfile):
+ if not Path(args.ssl_ca_certfile).is_file():
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
- if not Path.is_file(args.ssl_certfile):
+ if not Path(args.ssl_certfile).is_file():
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
- if not Path.is_file(args.ssl_keyfile):
+ if not Path(args.ssl_keyfile).is_file():
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
certificates = (
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate