diff --git a/datasets/README.md b/datasets/README.md index 1d8014d57ea3..50fc67376ae4 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -42,11 +42,12 @@ Create **custom partitioning schemes** or choose from the **implemented [partiti * IID partitioning `IidPartitioner(num_partitions)` * Dirichlet partitioning `DirichletPartitioner(num_partitions, partition_by, alpha)` * InnerDirichlet partitioning `InnerDirichletPartitioner(partition_sizes, partition_by, alpha)` -* Natural ID partitioner `NaturalIdPartitioner(partition_by)` -* Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) `SizePartitioner` -* Linear partitioner `LinearPartitioner(num_partitions)` -* Square partitioner `SquarePartitioner(num_partitions)` -* Exponential partitioner `ExponentialPartitioner(num_partitions)` +* Pathological partitioning `PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)` +* Natural ID partitioning `NaturalIdPartitioner(partition_by)` +* Size-based partitioning (the abstract base class for partitioners dictating the division based on the number of samples) `SizePartitioner` +* Linear partitioning `LinearPartitioner(num_partitions)` +* Square partitioning `SquarePartitioner(num_partitions)` +* Exponential partitioning `ExponentialPartitioner(num_partitions)` * more to come in future releases (contributions are welcome).

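A minimal usage sketch of the new `PathologicalPartitioner` (mirroring the class docstring example added later in this diff; the dataset name and the `partition_by` column are illustrative):

```python
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import PathologicalPartitioner

# Give each of the 10 partitions samples from exactly 2 classes
partitioner = PathologicalPartitioner(
    num_partitions=10,
    partition_by="label",
    num_classes_per_partition=2,
    class_assignment_mode="first-deterministic",
)
fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
partition = fds.load_partition(0)
```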
Comparison of partitioning schemes. diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst index bdcea7650bbc..fcc7920711bf 100644 --- a/datasets/doc/source/index.rst +++ b/datasets/doc/source/index.rst @@ -94,6 +94,7 @@ Here are a few of the ``Partitioner`` s that are available: (for a full list see * IID partitioning ``IidPartitioner(num_partitions)`` * Dirichlet partitioning ``DirichletPartitioner(num_partitions, partition_by, alpha)`` * InnerDirichlet partitioning ``InnerDirichletPartitioner(partition_sizes, partition_by, alpha)`` +* Pathological partitioner ``PathologicalPartitioner(num_partitions, partition_by, num_classes_per_partition, class_assignment_mode)`` * Natural ID partitioner ``NaturalIdPartitioner(partition_by)`` * Size partitioner (the abstract base class for the partitioners dictating the division based the number of samples) ``SizePartitioner`` * Linear partitioner ``LinearPartitioner(num_partitions)`` diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 1fc00ed90323..0c75dbce387a 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -22,6 +22,7 @@ from .linear_partitioner import LinearPartitioner from .natural_id_partitioner import NaturalIdPartitioner from .partitioner import Partitioner +from .pathological_partitioner import PathologicalPartitioner from .shard_partitioner import ShardPartitioner from .size_partitioner import SizePartitioner from .square_partitioner import SquarePartitioner @@ -34,6 +35,7 @@ "LinearPartitioner", "NaturalIdPartitioner", "Partitioner", + "PathologicalPartitioner", "ShardPartitioner", "SizePartitioner", "SquarePartitioner", diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py new file mode 100644 index 000000000000..1ee60d283044 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py @@ -0,0 +1,305 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pathological partitioner class that works with Hugging Face Datasets.""" + + +import warnings +from typing import Any, Dict, List, Literal, Optional + +import numpy as np + +import datasets +from flwr_datasets.common.typing import NDArray +from flwr_datasets.partitioner.partitioner import Partitioner + + +# pylint: disable=too-many-arguments, too-many-instance-attributes +class PathologicalPartitioner(Partitioner): + """Partition dataset such that each partition has a chosen number of classes. + + Implementation based on Federated Learning on Non-IID Data Silos: An Experimental + Study https://arxiv.org/pdf/2102.02079. + + The algorithm first determines which classes will be assigned to which partitions.
+ For each partition, `num_classes_per_partition` classes are sampled according to + `class_assignment_mode`. Given the information about the required classes for each + partition, it is determined into how many parts the samples corresponding to each + label need to be divided. Such a division is performed for each class. + + Parameters + ---------- + num_partitions : int + The total number of partitions that the data will be divided into. + partition_by : str + Column name of the labels (targets) based on which partitioning works. + num_classes_per_partition: int + The (exact) number of unique classes that each partition will have. + class_assignment_mode: Literal["random", "deterministic", "first-deterministic"] + The way the classes are assigned to the partitions. The default is "random". + The possible values are: + + - "random": Randomly assign classes to the partitions. For each partition, choose + `num_classes_per_partition` classes without replacement. + - "first-deterministic": Assign the first class for each partition in a + deterministic way (the class id is partition_id % num_unique_classes). + The rest of the classes are assigned randomly. In case the number of + partitions is smaller than the number of unique classes, not all classes will + be used in the first iteration; otherwise all the classes will be used (such + that each class will be present in at least one partition). + - "deterministic": Assign all the classes to the partitions in a deterministic + way. Classes are assigned based on the formula: partition_id has classes + identified by the index: (partition_id + i) % num_unique_classes, + where i is in {0, ..., num_classes_per_partition - 1}. So, partition 0 will have + classes 0, 1, 2, ..., `num_classes_per_partition`-1, partition 1 will have + classes 1, 2, 3, ..., `num_classes_per_partition`, and so on. + + The list representing the unique labels is sorted in ascending order. In case + of numbers starting from zero, the class id corresponds to the number itself. + `class_assignment_mode="first-deterministic"` was used in the original paper; + here we provide the option to use the other modes as well. + shuffle: bool + Whether to randomize the order of samples. Shuffling is applied after the + samples are assigned to partitions. + seed: int + Seed used for dataset shuffling. It has no effect if `shuffle` is False.
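+ + As an illustrative example (not drawn from the original paper): with 10 unique + classes, `num_classes_per_partition=2`, and `class_assignment_mode="deterministic"`, + partition 0 is assigned classes {0, 1}, partition 1 classes {1, 2}, partition 2 + classes {2, 3}, and so on.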
+ + Examples + -------- + To mimic the original behavior of the paper, follow the setup below + (with `class_assignment_mode="first-deterministic"`): + + >>> from flwr_datasets.partitioner import PathologicalPartitioner + >>> from flwr_datasets import FederatedDataset + >>> + >>> partitioner = PathologicalPartitioner( + >>> num_partitions=10, + >>> partition_by="label", + >>> num_classes_per_partition=2, + >>> class_assignment_mode="first-deterministic" + >>> ) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + """ + + def __init__( + self, + num_partitions: int, + partition_by: str, + num_classes_per_partition: int, + class_assignment_mode: Literal[ + "random", "deterministic", "first-deterministic" + ] = "random", + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + self._num_partitions = num_partitions + self._partition_by = partition_by + self._num_classes_per_partition = num_classes_per_partition + self._class_assignment_mode = class_assignment_mode + self._shuffle = shuffle + self._seed = seed + self._rng = np.random.default_rng(seed=self._seed) + + # Utility attributes + self._partition_id_to_indices: Dict[int, List[int]] = {} + self._partition_id_to_unique_labels: Dict[int, List[Any]] = { + pid: [] for pid in range(self._num_partitions) + } + self._unique_labels: List[Any] = [] + # Count in how many partitions the label is used + self._unique_label_to_times_used_counter: Dict[Any, int] = {} + self._partition_id_to_indices_determined = False + + def load_partition(self, partition_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + partition_id : int + The index that corresponds to the requested partition. + + Returns + ------- + dataset_partition : Dataset + Single partition of a dataset. + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the index assignments for all the + # partitions.
+ self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self.dataset.select(self._partition_id_to_indices[partition_id]) + + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._determine_partition_id_to_indices_if_needed() + return self._num_partitions + + def _determine_partition_id_to_indices_if_needed(self) -> None: + """Create an assignment of indices to the partition indices.""" + if self._partition_id_to_indices_determined: + return + self._determine_partition_id_to_unique_labels() + assert self._unique_labels is not None + self._count_partitions_having_each_unique_label() + + labels = np.asarray(self.dataset[self._partition_by]) + self._check_correctness_of_unique_label_to_times_used_counter(labels) + for partition_id in range(self._num_partitions): + self._partition_id_to_indices[partition_id] = [] + + unused_labels = [] + for unique_label in self._unique_labels: + if self._unique_label_to_times_used_counter[unique_label] == 0: + unused_labels.append(unique_label) + continue + # Get the indices in the original dataset where label == unique_label + unique_label_to_indices = np.where(labels == unique_label)[0] + + split_unique_labels_to_indices = np.array_split( + unique_label_to_indices, + self._unique_label_to_times_used_counter[unique_label], + ) + + split_index = 0 + for partition_id in range(self._num_partitions): + if unique_label in self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_indices[partition_id].extend( + split_unique_labels_to_indices[split_index] + ) + split_index += 1 + + if len(unused_labels) >= 1: + warnings.warn( + f"Classes: {unused_labels} will NOT be used due to the chosen " + f"configuration. If this behavior is undesired, consider setting" + f" `class_assignment_mode='first-deterministic'`, which, when the" + f" number of classes is smaller than the number of partitions, will " + f"utilize all the classes in the created partitions.", + stacklevel=1, + ) + if self._shuffle: + for indices in self._partition_id_to_indices.values(): + # In-place shuffling + self._rng.shuffle(indices) + + self._partition_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._partition_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _determine_partition_id_to_unique_labels(self) -> None: + """Determine the assignment of unique labels to the partitions.""" + self._unique_labels = sorted(self.dataset.unique(self._partition_by)) + num_unique_classes = len(self._unique_labels) + + if self._num_classes_per_partition > num_unique_classes: + raise ValueError( + f"The specified `num_classes_per_partition`" + f"={self._num_classes_per_partition} is greater than the number " + f"of unique classes in the given dataset={num_unique_classes}. " + f"Reduce `num_classes_per_partition` or use a different dataset " + f"to apply this partitioning."
+ ) + if self._class_assignment_mode == "first-deterministic": + for partition_id in range(self._num_partitions): + label = self._unique_labels[partition_id % num_unique_classes] + self._partition_id_to_unique_labels[partition_id].append(label) + + while ( + len(self._partition_id_to_unique_labels[partition_id]) + < self._num_classes_per_partition + ): + label = self._rng.choice(self._unique_labels, size=1)[0] + if label not in self._partition_id_to_unique_labels[partition_id]: + self._partition_id_to_unique_labels[partition_id].append(label) + elif self._class_assignment_mode == "deterministic": + for partition_id in range(self._num_partitions): + labels = [] + for i in range(self._num_classes_per_partition): + label = self._unique_labels[ + (partition_id + i) % len(self._unique_labels) + ] + labels.append(label) + self._partition_id_to_unique_labels[partition_id] = labels + elif self._class_assignment_mode == "random": + for partition_id in range(self._num_partitions): + labels = self._rng.choice( + self._unique_labels, + size=self._num_classes_per_partition, + replace=False, + ).tolist() + self._partition_id_to_unique_labels[partition_id] = labels + else: + raise ValueError( + f"The supported `class_assignment_mode` values are: 'random', " + f"'deterministic', 'first-deterministic'. " + f"You provided: {self._class_assignment_mode}." + ) + + def _count_partitions_having_each_unique_label(self) -> None: + """Count the number of partitions that have each unique label. + + This computation is based on the assignment of the label to the partition_id in + the `_determine_partition_id_to_unique_labels` method. + Given: + * partition 0 has only labels: 0, 1 (not necessarily just two samples; it can + have many samples, but all of them with label 0 or 1), + * partition 1 has only labels: 1, 2 (same note as above), + * and there are only two partitions, then the following will be computed: + { + 0: 1, + 1: 2, + 2: 1 + } + """ + for unique_label in self._unique_labels: + self._unique_label_to_times_used_counter[unique_label] = 0 + for unique_labels in self._partition_id_to_unique_labels.values(): + for unique_label in unique_labels: + self._unique_label_to_times_used_counter[unique_label] += 1 + + def _check_correctness_of_unique_label_to_times_used_counter( + self, labels: NDArray + ) -> None: + """Check if partitioning is possible given the presence requirements. + + The number of times a label can be used must be smaller than or equal to the + number of times that the label is present in the dataset. + """ + for unique_label in self._unique_labels: + num_unique = np.sum(labels == unique_label) + if self._unique_label_to_times_used_counter[unique_label] > num_unique: + raise ValueError( + f"Label: {unique_label} needs to be assigned to more " + f"partitions " + f"({self._unique_label_to_times_used_counter[unique_label]})" + f" than there are samples (corresponding to this label) in the " + f"dataset ({num_unique}). Please decrease `num_partitions` or " + f"`num_classes_per_partition` to avoid this situation, " + f"or try `class_assignment_mode='deterministic'` to create a more " + f"even distribution of classes across the partitions. " + f"Alternatively, use a different dataset if you cannot adjust" + f" any of these parameters."
+ ) diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py new file mode 100644 index 000000000000..151b7e14659c --- /dev/null +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py @@ -0,0 +1,262 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for PathologicalPartitioner.""" + + +import unittest +from typing import Dict + +import numpy as np +from parameterized import parameterized + +import datasets +from datasets import Dataset +from flwr_datasets.partitioner.pathological_partitioner import PathologicalPartitioner + + +def _dummy_dataset_setup( + num_samples: int, partition_by: str, num_unique_classes: int +) -> Dataset: + """Create a dummy dataset for testing.""" + data = { + partition_by: np.tile( + np.arange(num_unique_classes), num_samples // num_unique_classes + 1 + )[:num_samples], + "features": np.random.randn(num_samples), + } + return Dataset.from_dict(data) + + +class TestPathologicalPartitioner(unittest.TestCase): + """Unit tests for PathologicalPartitioner.""" + + @parameterized.expand( # type: ignore + [ + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (3, 1, 60, 3), # Single class per partition scenario + (5, 2, 100, 5), + (5, 2, 100, 10), + (4, 3, 120, 6), + ] + ) + def test_correct_num_classes_when_partitioned( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test correct number of unique classes.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitions: Dict[int, Dataset] = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + unique_classes_per_partition = { + pid: np.unique(partition["labels"]) for pid, partition in partitions.items() + } + + for unique_classes in unique_classes_per_partition.values(): + self.assertEqual(num_classes_per_partition, len(unique_classes)) + + def test_first_class_deterministic_assignment(self) -> None: + """Test deterministic assignment of first classes to partitions. + + Test if all the classes are used (which has to be the case, given + num_partitions >= the number of unique classes).
+ """ + dataset = _dummy_dataset_setup(100, "labels", 10) + partitioner = PathologicalPartitioner( + num_partitions=10, + partition_by="labels", + num_classes_per_partition=2, + class_assignment_mode="first-deterministic", + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + expected_classes = set(range(10)) + actual_classes = set() + for pid in range(10): + partition = partitioner.load_partition(pid) + actual_classes.update(np.unique(partition["labels"])) + self.assertEqual(expected_classes, actual_classes) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (4, 2, 80, 8), + (10, 2, 100, 10), + ] + ) + def test_deterministic_class_assignment( + self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes + ): + """Test deterministic assignment of classes to partitions.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="deterministic", + ) + partitioner.dataset = dataset + partitions = { + pid: partitioner.load_partition(pid) for pid in range(num_partitions) + } + + # Verify each partition has the expected classes, order does not matter + for pid, partition in partitions.items(): + expected_labels = sorted( + [ + (pid + i) % num_unique_classes + for i in range(num_classes_per_partition) + ] + ) + actual_labels = sorted(np.unique(partition["labels"])) + self.assertTrue( + np.array_equal(expected_labels, actual_labels), + f"Partition {pid} does not have the expected labels: " + f"{expected_labels} but instead {actual_labels}.", + ) + + @parameterized.expand( + [ # type: ignore + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 3, 20, 3), + ] + ) + def test_too_many_partitions_for_a_class( + self, num_partitions, num_classes_per_partition, num_samples, num_unique_classes + ) -> None: + """Test too many partitions for the number of samples in a class.""" + dataset_1 = _dummy_dataset_setup( + num_samples // 2, "labels", num_unique_classes - 1 + ) + # Create a skewed part of the dataset for the last label + data = { + "labels": np.array([num_unique_classes - 1] * (num_samples // 2)), + "features": np.random.randn(num_samples // 2), + } + dataset_2 = Dataset.from_dict(data) + dataset = datasets.concatenate_datasets([dataset_1, dataset_2]) + + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + class_assignment_mode="random", + ) + partitioner.dataset = dataset + + with self.assertRaises(ValueError) as context: + _ = partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "Label: 0 is needed to be assigned to more partitions (10) than there are " + "samples (corresponding to this label) in the dataset (5). " + "Please decrease the `num_partitions`, `num_classes_per_partition` to " + "avoid this situation, or try `class_assigment_mode='deterministic'` to " + "create a more even distribution of classes along the partitions. 
" + "Alternatively use a different dataset if you can not adjust the any of " + "these parameters.", + ) + + @parameterized.expand( # type: ignore + [ + # num_partitions, num_classes_per_partition, num_samples, num_unique_classes + (10, 11, 100, 10), # 11 > 10 + (5, 11, 100, 10), # 11 > 10 + (10, 20, 100, 5), # 20 > 5 + ] + ) + def test_more_classes_per_partition_than_num_unique_classes_in_dataset_raises( + self, + num_partitions: int, + num_classes_per_partition: int, + num_samples: int, + num_unique_classes: int, + ) -> None: + """Test more num_classes_per_partition > num_unique_classes in the dataset.""" + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The specified " + f"`num_classes_per_partition`={num_classes_per_partition} is " + f"greater than the number of unique classes in the given " + f"dataset={len(dataset.unique('labels'))}. Reduce the " + f"`num_classes_per_partition` or make use different dataset " + f"to apply this partitioning.", + ) + + @parameterized.expand( # type: ignore + [ + # num_classes_per_partition should be irrelevant since the exception should + # be raised at the very beginning + # num_partitions, num_classes_per_partition, num_samples + (10, 2, 5), + (10, 10, 5), + (100, 10, 99), + ] + ) + def test_more_partitions_than_samples_raises( + self, num_partitions: int, num_classes_per_partition: int, num_samples: int + ) -> None: + """Test if generation of more partitions that there are samples raises.""" + # The number of unique classes in the dataset should be irrelevant since the + # exception should be raised at the very beginning + dataset = _dummy_dataset_setup(num_samples, "labels", num_unique_classes=5) + with self.assertRaises(ValueError) as context: + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="labels", + num_classes_per_partition=num_classes_per_partition, + ) + partitioner.dataset = dataset + partitioner.load_partition(0) + self.assertEqual( + str(context.exception), + "The number of partitions needs to be smaller than the number of " + "samples in the dataset.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py index e82f17088bd9..c7b0d59b8ea5 100644 --- a/e2e/bare-client-auth/client.py +++ b/e2e/bare-client-auth/client.py @@ -1,13 +1,14 @@ import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context model_params = np.array([1]) objective = 5 # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -23,10 +24,10 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 8f5c1412fd01..4a682af3aec3 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -2,14 +2,15 @@ import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from 
flwr.common import Context model_params = np.array([1]) objective = 5 # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -25,17 +26,17 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 402d775ac3a9..943e60d5db9f 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -2,8 +2,8 @@ import numpy as np -import flwr as fl -from flwr.common import ConfigsRecord +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import ConfigsRecord, Context SUBSET_SIZE = 1000 STATE_VAR = "timestamp" @@ -14,7 +14,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -51,16 +51,14 @@ def evaluate(self, parameters, config): ) -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/docker/client.py b/e2e/docker/client.py index 8451b810416b..44313c7c3af6 100644 --- a/e2e/docker/client.py +++ b/e2e/docker/client.py @@ -9,6 +9,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from flwr.client import ClientApp, NumPyClient +from flwr.common import Context # ############################################################################# # 1. 
Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -122,7 +123,7 @@ def evaluate(self, parameters, config): return loss, len(testloader.dataset), {"accuracy": accuracy} -def client_fn(cid: str): +def client_fn(context: Context): """Create and return an instance of Flower `Client`.""" return FlowerClient().to_client() diff --git a/e2e/framework-fastai/client.py b/e2e/framework-fastai/client.py index 1d98a1134941..161b27b5a548 100644 --- a/e2e/framework-fastai/client.py +++ b/e2e/framework-fastai/client.py @@ -5,7 +5,8 @@ import torch from fastai.vision.all import * -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context warnings.filterwarnings("ignore", category=UserWarning) @@ -29,7 +30,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in learn.model.state_dict().items()] @@ -49,18 +50,18 @@ def evaluate(self, parameters, config): return loss, len(dls.valid), {"accuracy": 1 - error_rate} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-jax/client.py b/e2e/framework-jax/client.py index 347a005d923a..c9ff67b3e38e 100644 --- a/e2e/framework-jax/client.py +++ b/e2e/framework-jax/client.py @@ -6,7 +6,8 @@ import jax_training import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Load data and determine model shape train_x, train_y, test_x, test_y = jax_training.load_data() @@ -14,7 +15,7 @@ model_shape = train_x.shape[1:] -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self): self.params = jax_training.load_model(model_shape) @@ -48,16 +49,14 @@ def evaluate( return float(loss), num_examples, {"loss": float(loss)} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-opacus/client.py b/e2e/framework-opacus/client.py index c9ebe319063a..167fa4584e37 100644 --- a/e2e/framework-opacus/client.py +++ b/e2e/framework-opacus/client.py @@ -9,7 +9,8 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Define parameters. PARAMS = { @@ -95,7 +96,7 @@ def load_data(): # Define Flower client. -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model) -> None: super().__init__() # Create a privacy engine which will add DP and keep track of the privacy budget. 
@@ -139,16 +140,16 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} -def client_fn(cid): +def client_fn(context: Context): model = Net() return FlowerClient(model).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient(model).to_client() ) diff --git a/e2e/framework-pandas/client.py b/e2e/framework-pandas/client.py index 19e15f5a3b11..0c3300e1dd3f 100644 --- a/e2e/framework-pandas/client.py +++ b/e2e/framework-pandas/client.py @@ -3,7 +3,8 @@ import numpy as np import pandas as pd -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context df = pd.read_csv("./data/client.csv") @@ -16,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def fit( self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: @@ -32,17 +33,17 @@ def fit( ) -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py index fdd55b3dc344..bf291a1ca2c5 100644 --- a/e2e/framework-pytorch-lightning/client.py +++ b/e2e/framework-pytorch-lightning/client.py @@ -4,10 +4,11 @@ import pytorch_lightning as pl import torch -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model, train_loader, val_loader, test_loader): self.model = model self.train_loader = train_loader @@ -51,7 +52,7 @@ def _set_parameters(model, parameters): model.load_state_dict(state_dict, strict=True) -def client_fn(cid): +def client_fn(context: Context): model = mnist.LitAutoEncoder() train_loader, val_loader, test_loader = mnist.load_data() @@ -59,7 +60,7 @@ def client_fn(cid): return FlowerClient(model, train_loader, val_loader, test_loader).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) @@ -71,7 +72,7 @@ def main() -> None: # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() - fl.client.start_client(server_address="127.0.0.1:8080", client=client) + start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/e2e/framework-pytorch/client.py b/e2e/framework-pytorch/client.py index dbfbfed1ffa7..ab4bc7b5c5b9 100644 --- a/e2e/framework-pytorch/client.py +++ b/e2e/framework-pytorch/client.py @@ -10,8 +10,8 @@ from torchvision.transforms import Compose, Normalize, ToTensor from tqdm import tqdm -import flwr as fl -from flwr.common import ConfigsRecord +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import ConfigsRecord, Context # ############################################################################# # 1. 
Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -89,7 +89,7 @@ def load_data(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] @@ -136,18 +136,18 @@ def set_parameters(model, parameters): return -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py index b0691e75a79d..24c6617c1289 100644 --- a/e2e/framework-scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -5,7 +5,8 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Load MNIST dataset from https://www.openml.org/d/554 (X_train, y_train), (X_test, y_test) = utils.load_mnist() @@ -26,7 +27,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): # type: ignore return utils.get_model_parameters(model) @@ -45,16 +46,14 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:8080", client=FlowerClient().to_client() - ) + start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-tensorflow/client.py b/e2e/framework-tensorflow/client.py index 779be0c3746d..351f495a3acb 100644 --- a/e2e/framework-tensorflow/client.py +++ b/e2e/framework-tensorflow/client.py @@ -2,7 +2,8 @@ import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context SUBSET_SIZE = 1000 @@ -18,7 +19,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -33,16 +34,14 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index 505340e013a5..0403416cc3b7 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -2,7 +2,8 @@ import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context SUBSET_SIZE = 1000 @@ -33,7 +34,7 @@ def get_model(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ 
-48,17 +49,15 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py index abf9cdb5a5c7..c567f33b236b 100644 --- a/e2e/strategies/test.py +++ b/e2e/strategies/test.py @@ -3,8 +3,8 @@ import tensorflow as tf from client import SUBSET_SIZE, FlowerClient, get_model -import flwr as fl -from flwr.common import ndarrays_to_parameters +from flwr.common import Context, ndarrays_to_parameters +from flwr.server import ServerConfig from flwr.server.strategy import ( FaultTolerantFedAvg, FedAdagrad, @@ -15,6 +15,7 @@ FedYogi, QFedAvg, ) +from flwr.simulation import start_simulation STRATEGY_LIST = [ FedMedian, @@ -42,8 +43,7 @@ def get_strat(name): init_model = get_model() -def client_fn(cid): - _ = cid +def client_fn(context: Context): return FlowerClient() @@ -71,10 +71,10 @@ def evaluate(server_round, parameters, config): if start_idx >= OPT_IDX: strat_args["tau"] = 0.01 -hist = fl.simulation.start_simulation( +hist = start_simulation( client_fn=client_fn, num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), + config=ServerConfig(num_rounds=3), strategy=strategy(**strat_args), ) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 851083d4abb7..348ef8910dd3 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,7 +18,7 @@ import sys import time from dataclasses import dataclass -from logging import DEBUG, ERROR, INFO, WARN +from logging import ERROR, INFO, WARN from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union @@ -28,7 +28,7 @@ from flwr.client.client import Client from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.client.typing import ClientFnExt -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -138,8 +138,8 @@ class `flwr.client.Client` (default: None) Starting an SSL-enabled gRPC client using system certificates: - >>> def client_fn(node_id: int, partition_id: Optional[int]): - >>> return FlowerClient() + >>> def client_fn(context: Context): + >>> return FlowerClient().to_client() >>> >>> start_client( >>> server_address=localhost:8080, @@ -160,6 +160,7 @@ class `flwr.client.Client` (default: None) event(EventType.START_CLIENT_ENTER) _start_client_internal( server_address=server_address, + node_config={}, load_client_app_fn=None, client_fn=client_fn, client=client, @@ -181,6 +182,7 @@ class `flwr.client.Client` (default: None) def _start_client_internal( *, server_address: str, + node_config: Dict[str, str], load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None, client_fn: Optional[ClientFnExt] = None, client: Optional[Client] = None, @@ -193,7 +195,6 @@ def _start_client_internal( ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, - partition_id: Optional[int] = None, flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower 
client node which connects to a Flower server. @@ -204,6 +205,8 @@ The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. + node_config: Dict[str, str] + The configuration of the node. load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None) A function that can be used to load a `ClientApp` instance. client_fn : Optional[ClientFnExt] @@ -238,9 +241,6 @@ class `flwr.client.Client` (default: None) The maximum duration before the client stops trying to connect to the server in case of connection error. If set to None, there is no limit to the total time. - partition_id: Optional[int] (default: None) - The data partition index associated with this node. Better suited for - prototyping purposes. flwr_dir: Optional[Path] (default: None) The fully resolved path containing installed Flower Apps. """ @@ -253,8 +253,7 @@ class `flwr.client.Client` (default: None) if client_fn is None: # Wrap `Client` instance in `client_fn` def single_client_factory( - node_id: int, # pylint: disable=unused-argument - partition_id: Optional[int], # pylint: disable=unused-argument + context: Context, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy raise ValueError( @@ -295,7 +294,7 @@ def _on_backoff(retry_state: RetryState) -> None: log(WARN, "Connection attempt failed, retrying...") else: log( - DEBUG, + WARN, "Connection attempt failed, retrying in %.2f seconds", retry_state.actual_wait, ) @@ -319,7 +318,9 @@ def _on_backoff(retry_state: RetryState) -> None: on_backoff=_on_backoff, ) - node_state = NodeState(partition_id=partition_id) + # NodeState gets initialized when the first connection is established + node_state: Optional[NodeState] = None + runs: Dict[int, Run] = {} while not app_state_tracker.interrupt: @@ -334,9 +335,31 @@ def _on_backoff(retry_state: RetryState) -> None: ) as conn: receive, send, create_node, delete_node, get_run = conn - # Register node - if create_node is not None: - create_node() # pylint: disable=not-callable + # Register node when connecting the first time + if node_state is None: + if create_node is None: + if transport not in ["grpc-bidi", None]: + raise NotImplementedError( + "All transports except `grpc-bidi` require " + "an implementation for `create_node()`." + ) + # gRPC-bidi doesn't have the concept of node_id, + # so we set it to -1 + node_state = NodeState( + node_id=-1, + node_config={}, + ) + else: + # Call create_node fn to register node + node_id: Optional[int] = ( # pylint: disable=assignment-from-none + create_node() + ) # pylint: disable=not-callable + if node_id is None: + raise ValueError("Node registration failed") + node_state = NodeState( + node_id=node_id, + node_config=node_config, + ) app_state_tracker.register_signal_handler() while not app_state_tracker.interrupt: @@ -580,7 +603,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 663d83a8b19e..2a913b3a248d 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -30,21 +30,41 @@ from .typing import ClientAppCallable + + +def _alert_erroneous_client_fn() -> None: + raise ValueError(
"A `ClientApp` cannot make use of a `client_fn` that does " + "not have a signature in the form: `def client_fn(context: " + "Context)`. You can import the `Context` like this: " + "`from flwr.common import Context`" + ) + + def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt: client_fn_args = inspect.signature(client_fn).parameters + first_arg = list(client_fn_args.keys())[0] + + if len(client_fn_args) != 1: + _alert_erroneous_client_fn() + + first_arg_type = client_fn_args[first_arg].annotation - if not all(key in client_fn_args for key in ["node_id", "partition_id"]): + if first_arg_type is str or first_arg == "cid": + # Warn previous signature for `client_fn` seems to be used warn_deprecated_feature( - "`client_fn` now expects a signature `def client_fn(node_id: int, " - "partition_id: Optional[int])`.\nYou provided `client_fn` with signature: " - f"{dict(client_fn_args.items())}" + "`client_fn` now expects a signature `def client_fn(context: Context)`." + "The provided `client_fn` has signature: " + f"{dict(client_fn_args.items())}. You can import the `Context` like this:" + " `from flwr.common import Context`" ) # Wrap depcreated client_fn inside a function with the expected signature def adaptor_fn( - node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument - ) -> Client: - return client_fn(str(partition_id)) # type: ignore + context: Context, + ) -> Client: # pylint: disable=unused-argument + # if patition-id is defined, pass it. Else pass node_id that should + # always be defined during Context init. + cid = context.node_config.get("partition-id", context.node_id) + return client_fn(str(cid)) # type: ignore return adaptor_fn @@ -71,7 +91,7 @@ class ClientApp: >>> class FlowerClient(NumPyClient): >>> # ... 
>>> - >>> def client_fn(node_id: int, partition_id: Optional[int]): + >>> def client_fn(context: Context): >>> return FlowerClient().to_client() >>> >>> app = ClientApp(client_fn) diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index 971b630e470b..80a5cf0b4656 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -44,7 +44,7 @@ def grpc_adapter( # pylint: disable=R0913 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 3e9f261c1ecf..a6417106d51b 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -72,7 +72,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 8062ce28fcc7..e573df6854bc 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -79,7 +79,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] @@ -176,7 +176,7 @@ def ping() -> None: if not ping_stop_event.is_set(): ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" # Call FleetAPI create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) @@ -189,6 +189,7 @@ def create_node() -> None: nonlocal node, ping_thread node = cast(Node, create_node_response.node) ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e9a853a92101..1ab84eb01468 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype( client_fn: ClientFnExt, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most mod.""" - client = client_fn(message.metadata.dst_node_id, context.partition_id) + client = client_fn(context) # Check if NumPyClient is returend if isinstance(client, NumPyClient): diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 9ce4c9620c43..557d61ffb32a 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -19,7 +19,7 @@ import unittest import uuid from copy import copy -from typing import List, Optional +from typing import List from flwr.client import Client from flwr.client.typing import ClientFnExt @@ -114,9 +114,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: def 
_get_client_fn(client: Client) -> ClientFnExt: - def client_fn( - node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument - ) -> Client: + def client_fn(context: Context) -> Client: # pylint: disable=unused-argument return client return client_fn @@ -145,7 +143,7 @@ def test_client_without_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet(), run_config={}), + context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert @@ -209,7 +207,7 @@ def test_client_with_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet(), run_config={}), + context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 5e4c4411e1f7..2832576fb4fc 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -73,7 +73,12 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: def _make_ctxt() -> Context: cfg = ConfigsRecord(SecAggPlusState().to_dict()) - return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}), run_config={}) + return Context( + node_id=123, + node_config={}, + state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}), + run_config={}, + ) def _make_set_state_fn( diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index 7a1dd8988399..a5bbd0a0bb4d 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -104,7 +104,7 @@ def test_multiple_mods(self) -> None: state = RecordSet() state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) - context = Context(state=state, run_config={}) + context = Context(node_id=0, node_config={}, state=state, run_config={}) message = _get_dummy_flower_message() # Execute @@ -129,7 +129,7 @@ def test_filter(self) -> None: # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) message = _get_dummy_flower_message() def filter_mod( diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 2b090eba9720..393ca4564a35 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional +from typing import Dict, Optional from flwr.common import Context, RecordSet from flwr.common.config import get_fused_config @@ -35,10 +35,14 @@ class RunInfo: class NodeState: """State of a node where client nodes execute runs.""" - def __init__(self, partition_id: Optional[int]) -> None: - self._meta: Dict[str, Any] = {} # holds metadata about the node + def __init__( + self, + node_id: int, + node_config: Dict[str, str], + ) -> None: + self.node_id = node_id + self.node_config = node_config self.run_infos: Dict[int, RunInfo] = {} - self._partition_id = partition_id def register_context( self, @@ -52,9 +56,10 @@ def register_context( self.run_infos[run_id] = RunInfo( initial_run_config=initial_run_config, context=Context( + node_id=self.node_id, +
node_config=self.node_config, state=RecordSet(), run_config=initial_run_config.copy(), - partition_id=self._partition_id, ), ) diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index effd64a3ae7a..26ac4fea6855 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None: expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} # NodeState - node_state = NodeState(partition_id=None) + node_state = NodeState(node_id=0, node_config={}) for task in tasks: run_id = task.run_id diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 0efa5731ae51..3e81969d898c 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -90,7 +90,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] @@ -237,19 +237,20 @@ def ping() -> None: if not ping_stop_event.is_set(): ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) # Send the request res = _request(req, CreateNodeResponse, PATH_CREATE_NODE) if res is None: - return + return None # Remember the node and the ping-loop thread nonlocal node, ping_thread node = res.node ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 355a2a13a0e5..d61b986bc7af 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -29,7 +29,12 @@ from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import EventType, event -from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir +from flwr.common.config import ( + get_flwr_dir, + get_project_config, + get_project_dir, + parse_config_args, +) from flwr.common.constant import ( TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, @@ -67,7 +72,7 @@ def run_supernode() -> None: authentication_keys=authentication_keys, max_retries=args.max_retries, max_wait_time=args.max_wait_time, - partition_id=args.partition_id, + node_config=parse_config_args(args.node_config), flwr_dir=get_flwr_dir(args.flwr_dir), ) @@ -93,6 +98,7 @@ def run_client_app() -> None: _start_client_internal( server_address=args.superlink, + node_config=parse_config_args(args.node_config), load_client_app_fn=load_fn, transport=args.transport, root_certificates=root_certificates, @@ -389,11 +395,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: help="The SuperNode's public key (as a path str) to enable authentication.", ) parser.add_argument( - "--partition-id", - type=int, - help="The data partition index associated with this SuperNode. Better suited " - "for prototyping purposes where a SuperNode might only load a fraction of an " - "artificially partitioned dataset (e.g. using `flwr-datasets`)", + "--node-config", + type=str, + help="A comma separated list of key/value pairs (separated by `=`) to " + "configure the SuperNode. " + "E.g. 
diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py
index bf66a9082c77..9faed4bc7283 100644
--- a/src/py/flwr/client/typing.py
+++ b/src/py/flwr/client/typing.py
@@ -15,7 +15,7 @@
 """Custom types for Flower clients."""
 
-from typing import Callable, Optional
+from typing import Callable
 
 from flwr.common import Context, Message
 
@@ -23,7 +23,7 @@
 
 # Compatibility
 ClientFn = Callable[[str], Client]
-ClientFnExt = Callable[[int, Optional[int]], Client]
+ClientFnExt = Callable[[Context], Client]
 
 ClientAppCallable = Callable[[Message, Context], Message]
 Mod = Callable[[Message, Context, ClientAppCallable], Message]
diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py
index 9770bdb4af2b..54d74353e4ed 100644
--- a/src/py/flwr/common/config.py
+++ b/src/py/flwr/common/config.py
@@ -121,16 +121,16 @@ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, st
 
 
 def parse_config_args(
-    config_overrides: Optional[str],
+    config: Optional[str],
     separator: str = ",",
 ) -> Dict[str, str]:
     """Parse separator separated list of key-value pairs separated by '='."""
     overrides: Dict[str, str] = {}
 
-    if config_overrides is None:
+    if config is None:
         return overrides
 
-    overrides_list = config_overrides.split(separator)
+    overrides_list = config.split(separator)
     if (
         len(overrides_list) == 1
         and "=" not in overrides_list
diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py
index f14959589458..72256a62add7 100644
--- a/src/py/flwr/common/constant.py
+++ b/src/py/flwr/common/constant.py
@@ -57,6 +57,9 @@
 FAB_CONFIG_FILE = "pyproject.toml"
 FLWR_HOME = "FLWR_HOME"
 
+# Constants entries in Node config for Simulation
+PARTITION_ID_KEY = "partition-id"
+NUM_PARTITIONS_KEY = "num-partitions"
 
 GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
 GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py
index 8120723ce9e9..4da52ba44481 100644
--- a/src/py/flwr/common/context.py
+++ b/src/py/flwr/common/context.py
@@ -16,7 +16,7 @@
 
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict
 
 from .record import RecordSet
 
@@ -27,6 +27,11 @@ class Context:
 
     Parameters
     ----------
+    node_id : int
+        The ID that identifies the node.
+    node_config : Dict[str, str]
+        A config (key/value mapping) unique to the node and independent of the
+        `run_config`. This config persists across all runs this node participates in.
     state : RecordSet
         Holds records added by the entity in a given run and that will stay local.
         This means that the data it holds will never leave the system it's running from.
@@ -38,22 +43,21 @@ class Context:
         A config (key/value mapping) held by the entity in a given run and that will
         stay local. It can be used at any point during the lifecycle of this entity
         (e.g. across multiple rounds)
-    partition_id : Optional[int] (default: None)
-        An index that specifies the data partition that the ClientApp using this Context
-        object should make use of. Setting this attribute is better suited for
-        simulation or proto typing setups.
     """
 
+    node_id: int
+    node_config: Dict[str, str]
     state: RecordSet
-    partition_id: Optional[int]
     run_config: Dict[str, str]
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-arguments
         self,
+        node_id: int,
+        node_config: Dict[str, str],
         state: RecordSet,
         run_config: Dict[str, str],
-        partition_id: Optional[int] = None,
     ) -> None:
+        self.node_id = node_id
+        self.node_config = node_config
         self.state = state
         self.run_config = run_config
-        self.partition_id = partition_id
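A minimal before/after sketch of the new `Context` surface introduced above (the values are illustrative):

```python
from flwr.common import Context, RecordSet

# Before: Context(state=RecordSet(), run_config={}, partition_id=0)
# After: the partition index rides along in the node-scoped config instead
context = Context(
    node_id=17,                         # identifies the node (invented value)
    node_config={"partition-id": "0"},  # persists across all runs on this node
    state=RecordSet(),                  # run-local records, never leave the node
    run_config={},                      # run-scoped config
)
assert context.node_config["partition-id"] == "0"
```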
""" + node_id: int + node_config: Dict[str, str] state: RecordSet - partition_id: Optional[int] run_config: Dict[str, str] - def __init__( + def __init__( # pylint: disable=too-many-arguments self, + node_id: int, + node_config: Dict[str, str], state: RecordSet, run_config: Dict[str, str], - partition_id: Optional[int] = None, ) -> None: + self.node_id = node_id + self.node_config = node_config self.state = state self.run_config = run_config - self.partition_id = partition_id diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py index 9e120c824103..ee09d79012dc 100644 --- a/src/py/flwr/server/compat/legacy_context.py +++ b/src/py/flwr/server/compat/legacy_context.py @@ -52,4 +52,4 @@ def __init__( self.strategy = strategy self.client_manager = client_manager self.history = History() - super().__init__(state, run_config={}) + super().__init__(node_id=0, node_config={}, state=state, run_config={}) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index b4697e99913f..4cc25feb7e0e 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -78,7 +78,9 @@ def _load() -> ServerApp: server_app = _load() # Initialize Context - context = Context(state=RecordSet(), run_config=server_app_run_config) + context = Context( + node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config + ) # Call ServerApp server_app(driver=driver, context=context) diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index 7de8774d4c81..b0672b3202ed 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -29,7 +29,7 @@ def test_server_app_custom_mode() -> None: # Prepare app = ServerApp() driver = MagicMock() - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) called = {"called": False} diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 0d2f4d193f0b..0ab29a234f88 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -21,6 +21,7 @@ import ray from flwr.client.client_app import ClientApp +from flwr.common.constant import PARTITION_ID_KEY from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message @@ -168,7 +169,7 @@ def process_message( Return output message and updated context. 
""" - partition_id = context.partition_id + partition_id = context.node_config[PARTITION_ID_KEY] try: # Submit a task to the pool diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 287983003f8c..a38cff96ceef 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -23,6 +23,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.node_state import NodeState from flwr.common import ( DEFAULT_TTL, Config, @@ -32,9 +33,9 @@ Message, MessageTypeLegacy, Metadata, - RecordSet, Scalar, ) +from flwr.common.constant import PARTITION_ID_KEY from flwr.common.object_ref import load_app from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig @@ -53,9 +54,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: return {"result": result} -def get_dummy_client( - node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument -) -> Client: +def get_dummy_client(context: Context) -> Client: # pylint: disable=unused-argument """Return a DummyClient converted to Client type.""" return DummyClient().to_client() @@ -103,12 +102,13 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: # Construct a Message mult_factor = 2024 + run_id = 0 getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) recordset = getpropertiesins_to_recordset(getproperties_ins) message = Message( content=recordset, metadata=Metadata( - run_id=0, + run_id=run_id, message_id="", group_id="", src_node_id=0, @@ -119,8 +119,10 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: ), ) - # Construct emtpy Context - context = Context(state=RecordSet(), run_config={}) + # Construct NodeState and retrieve context + node_state = NodeState(node_id=run_id, node_config={PARTITION_ID_KEY: str(0)}) + node_state.register_context(run_id=run_id) + context = node_state.retrieve_context(run_id=run_id) # Expected output expected_output = pi * mult_factor diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 3c0b36e1ca3c..cd30c40167c5 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -29,7 +29,12 @@ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState -from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode +from flwr.common.constant import ( + NUM_PARTITIONS_KEY, + PARTITION_ID_KEY, + PING_MAX_INTERVAL, + ErrorCode, +) from flwr.common.logger import log from flwr.common.message import Error from flwr.common.object_ref import load_app @@ -73,7 +78,7 @@ def worker( task_ins: TaskIns = taskins_queue.get(timeout=1.0) node_id = task_ins.task.consumer.node_id - # Register and retrieve runstate + # Register and retrieve context node_states[node_id].register_context(run_id=task_ins.run_id) context = node_states[node_id].retrieve_context(run_id=task_ins.run_id) @@ -283,8 +288,16 @@ def start_vce( # Construct mapping of NodeStates node_states: Dict[int, NodeState] = {} + # Number of unique partitions + num_partitions = len(set(nodes_mapping.values())) for node_id, partition_id in nodes_mapping.items(): - 
diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py
index 446b0bdeba38..973a9a89e652 100644
--- a/src/py/flwr/simulation/app.py
+++ b/src/py/flwr/simulation/app.py
@@ -111,9 +111,9 @@ def start_simulation(
     Parameters
     ----------
     client_fn : ClientFnExt
-        A function creating Client instances. The function must have the signature
-        `client_fn(node_id: int, partition_id: Optional[int]). It should return
-        a single client instance of type Client. Note that the created client
+        A function creating `Client` instances. The function must have the signature
+        `client_fn(context: Context)`. It should return
+        a single client instance of type `Client`. Note that the created client
         instances are ephemeral and will often be destroyed after a single method
         invocation. Since client instances are not long-lived, they should not
         attempt to carry state over method invocations. Any state required by the instance
@@ -327,6 +327,7 @@ def update_resources(f_stop: threading.Event) -> None:
             client_fn=client_fn,
             node_id=node_id,
             partition_id=partition_id,
+            num_partitions=num_clients,
             actor_pool=pool,
         )
         initialized_server.client_manager().register(client=client_proxy)
diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
index 31bc22c84bd5..895272c2fd79 100644
--- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
+++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py
@@ -24,7 +24,12 @@
 from flwr.client.client_app import ClientApp
 from flwr.client.node_state import NodeState
 from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
-from flwr.common.constant import MessageType, MessageTypeLegacy
+from flwr.common.constant import (
+    NUM_PARTITIONS_KEY,
+    PARTITION_ID_KEY,
+    MessageType,
+    MessageTypeLegacy,
+)
 from flwr.common.logger import log
 from flwr.common.recordset_compat import (
     evaluateins_to_recordset,
@@ -43,11 +48,12 @@
 class RayActorClientProxy(ClientProxy):
     """Flower client proxy which delegates work using Ray."""
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-arguments
         self,
         client_fn: ClientFnExt,
         node_id: int,
         partition_id: int,
+        num_partitions: int,
         actor_pool: VirtualClientEngineActorPool,
     ):
         super().__init__(cid=str(node_id))
@@ -59,7 +65,13 @@ def _load_app() -> ClientApp:
 
         self.app_fn = _load_app
         self.actor_pool = actor_pool
-        self.proxy_state = NodeState(partition_id=self.partition_id)
+        self.proxy_state = NodeState(
+            node_id=node_id,
+            node_config={
+                PARTITION_ID_KEY: str(partition_id),
+                NUM_PARTITIONS_KEY: str(num_partitions),
+            },
+        )
 
     def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
         """Sumbit a message to the ActorPool."""
@@ -68,18 +80,19 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
         # Register state
         self.proxy_state.register_context(run_id=run_id)
 
-        # Retrieve state
-        state = self.proxy_state.retrieve_context(run_id=run_id)
+        # Retrieve context
+        context = self.proxy_state.retrieve_context(run_id=run_id)
+        partition_id_str = context.node_config[PARTITION_ID_KEY]
 
         try:
             self.actor_pool.submit_client_job(
-                lambda a, a_fn, mssg, partition_id, state: a.run.remote(
-                    a_fn, mssg, partition_id, state
+                lambda a, a_fn, mssg, partition_id, context: a.run.remote(
+                    a_fn, mssg, partition_id, context
                 ),
-                (self.app_fn, message, str(self.partition_id), state),
+                (self.app_fn, message, partition_id_str, context),
             )
             out_mssg, updated_context = self.actor_pool.get_client_result(
-                str(self.partition_id), timeout
+                partition_id_str, timeout
            )
 
             # Update state
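For app authors, the migration implied by the new `ClientFnExt` signature is mechanical. A hedged sketch (`FlowerClient` is a stand-in for any `NumPyClient`; the `"partition-id"` key matches `PARTITION_ID_KEY` above):

```python
from flwr.client import Client, NumPyClient
from flwr.common import Context


class FlowerClient(NumPyClient):
    """Stand-in client; a real app would implement fit/evaluate."""

    def __init__(self, partition_id: int) -> None:
        self.partition_id = partition_id


# Old-style: def client_fn(node_id: int, partition_id: Optional[int]) -> Client
def client_fn(context: Context) -> Client:
    # The partition now arrives via the node-level config (string-valued)
    partition_id = int(context.node_config["partition-id"])
    return FlowerClient(partition_id).to_client()
```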
diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
index 83f6cfe05313..62e0cfd61c99 100644
--- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
+++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
@@ -17,12 +17,13 @@
 
 from math import pi
 from random import shuffle
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Dict, List, Tuple, Type
 
 import ray
 
 from flwr.client import Client, NumPyClient
 from flwr.client.client_app import ClientApp
+from flwr.client.node_state import NodeState
 from flwr.common import (
     DEFAULT_TTL,
     Config,
@@ -31,15 +32,18 @@
     Message,
     MessageTypeLegacy,
     Metadata,
-    RecordSet,
     Scalar,
 )
+from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY
 from flwr.common.recordset_compat import (
     getpropertiesins_to_recordset,
     recordset_to_getpropertiesres,
 )
 from flwr.common.recordset_compat_test import _get_valid_getpropertiesins
-from flwr.simulation.app import _create_node_id_to_partition_mapping
+from flwr.simulation.app import (
+    NodeToPartitionMapping,
+    _create_node_id_to_partition_mapping,
+)
 from flwr.simulation.ray_transport.ray_actor import (
     ClientAppActor,
     VirtualClientEngineActor,
@@ -65,16 +69,16 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
         return {"result": result}
 
 
-def get_dummy_client(
-    node_id: int, partition_id: Optional[int]  # pylint: disable=unused-argument
-) -> Client:
+def get_dummy_client(context: Context) -> Client:
     """Return a DummyClient converted to Client type."""
-    return DummyClient(node_id).to_client()
+    return DummyClient(context.node_id).to_client()
 
 
 def prep(
     actor_type: Type[VirtualClientEngineActor] = ClientAppActor,
-) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]:  # pragma: no cover
+) -> Tuple[
+    List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping
+]:  # pragma: no cover
     """Prepare ClientProxies and pool for tests."""
     client_resources = {"num_cpus": 1, "num_gpus": 0.0}
 
@@ -96,12 +100,13 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:
             client_fn=get_dummy_client,
             node_id=node_id,
             partition_id=partition_id,
+            num_partitions=num_proxies,
             actor_pool=pool,
         )
         for node_id, partition_id in mapping.items()
     ]
 
-    return proxies, pool
+    return proxies, pool, mapping
 
 
 def test_cid_consistency_one_at_a_time() -> None:
@@ -109,7 +114,7 @@ def test_cid_consistency_one_at_a_time() -> None:
 
     Submit one job and waits for completion. Then submits the next and so on
     """
-    proxies, _ = prep()
+    proxies, _, _ = prep()
 
     getproperties_ins = _get_valid_getpropertiesins()
     recordset = getpropertiesins_to_recordset(getproperties_ins)
@@ -139,7 +144,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
 
     All jobs are submitted at the same time. Then fetched one at a time. This also
     tests NodeState (at each Proxy) and RunState basic functionality.
     """
-    proxies, _ = prep()
+    proxies, _, _ = prep()
     run_id = 0
 
     getproperties_ins = _get_valid_getpropertiesins()
@@ -186,9 +191,19 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
 
 def test_cid_consistency_without_proxies() -> None:
     """Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy."""
-    proxies, pool = prep()
-    num_clients = len(proxies)
-    node_ids = list(range(num_clients))
+    _, pool, mapping = prep()
+    node_ids = list(mapping.keys())
+
+    # register node states
+    node_states: Dict[int, NodeState] = {}
+    for node_id, partition_id in mapping.items():
+        node_states[node_id] = NodeState(
+            node_id=node_id,
+            node_config={
+                PARTITION_ID_KEY: str(partition_id),
+                NUM_PARTITIONS_KEY: str(len(node_ids)),
+            },
+        )
 
     getproperties_ins = _get_valid_getpropertiesins()
     recordset = getpropertiesins_to_recordset(getproperties_ins)
@@ -198,11 +213,12 @@ def _load_app() -> ClientApp:
 
     # submit all jobs (collect later)
     shuffle(node_ids)
+    run_id = 0
     for node_id in node_ids:
         message = Message(
             content=recordset,
             metadata=Metadata(
-                run_id=0,
+                run_id=run_id,
                 message_id="",
                 group_id=str(0),
                 src_node_id=0,
@@ -212,20 +228,20 @@ def _load_app() -> ClientApp:
                 message_type=MessageTypeLegacy.GET_PROPERTIES,
             ),
         )
+        # register and retrieve context
+        node_states[node_id].register_context(run_id=run_id)
+        context = node_states[node_id].retrieve_context(run_id=run_id)
+        partition_id_str = context.node_config[PARTITION_ID_KEY]
         pool.submit_client_job(
             lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state),
-            (
-                _load_app,
-                message,
-                str(node_id),
-                Context(state=RecordSet(), run_config={}, partition_id=node_id),
-            ),
+            (_load_app, message, partition_id_str, context),
         )
 
     # fetch results one at a time
     shuffle(node_ids)
 
     for node_id in node_ids:
-        message_out, _ = pool.get_client_result(str(node_id), timeout=None)
+        partition_id_str = str(mapping[node_id])
+        message_out, _ = pool.get_client_result(partition_id_str, timeout=None)
         res = recordset_to_getpropertiesres(message_out.content)
         assert node_id * pi == res.properties["result"]
""" - proxies, _ = prep() + proxies, _, _ = prep() run_id = 0 getproperties_ins = _get_valid_getpropertiesins() @@ -186,9 +191,19 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: def test_cid_consistency_without_proxies() -> None: """Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy.""" - proxies, pool = prep() - num_clients = len(proxies) - node_ids = list(range(num_clients)) + _, pool, mapping = prep() + node_ids = list(mapping.keys()) + + # register node states + node_states: Dict[int, NodeState] = {} + for node_id, partition_id in mapping.items(): + node_states[node_id] = NodeState( + node_id=node_id, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(len(node_ids)), + }, + ) getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) @@ -198,11 +213,12 @@ def _load_app() -> ClientApp: # submit all jobs (collect later) shuffle(node_ids) + run_id = 0 for node_id in node_ids: message = Message( content=recordset, metadata=Metadata( - run_id=0, + run_id=run_id, message_id="", group_id=str(0), src_node_id=0, @@ -212,20 +228,20 @@ def _load_app() -> ClientApp: message_type=MessageTypeLegacy.GET_PROPERTIES, ), ) + # register and retrieve context + node_states[node_id].register_context(run_id=run_id) + context = node_states[node_id].retrieve_context(run_id=run_id) + partition_id_str = context.node_config[PARTITION_ID_KEY] pool.submit_client_job( lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state), - ( - _load_app, - message, - str(node_id), - Context(state=RecordSet(), run_config={}, partition_id=node_id), - ), + (_load_app, message, partition_id_str, context), ) # fetch results one at a time shuffle(node_ids) for node_id in node_ids: - message_out, _ = pool.get_client_result(str(node_id), timeout=None) + partition_id_str = str(mapping[node_id]) + message_out, _ = pool.get_client_result(partition_id_str, timeout=None) res = recordset_to_getpropertiesres(message_out.content) assert node_id * pi == res.properties["result"] diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index a90422bd635f..35ca8fbb0b87 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -133,11 +133,11 @@ def _try_obtain_certificates( return None # Check if certificates are provided if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile: - if not Path.is_file(args.ssl_ca_certfile): + if not Path(args.ssl_ca_certfile).is_file(): sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.") - if not Path.is_file(args.ssl_certfile): + if not Path(args.ssl_certfile).is_file(): sys.exit("Path argument `--ssl-certfile` does not point to a file.") - if not Path.is_file(args.ssl_keyfile): + if not Path(args.ssl_keyfile).is_file(): sys.exit("Path argument `--ssl-keyfile` does not point to a file.") certificates = ( Path(args.ssl_ca_certfile).read_bytes(), # CA certificate