From 1057001fc05ace6dcb87b373ff251bc870f7fc72 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 12 Mar 2024 22:36:40 +0100 Subject: [PATCH] Fds add num_partitions property to partitioners (#3095) * Add num_partition property * Trigger the partitioning in the num_partitions --------- Co-authored-by: Daniel J. Beutel --- .../partitioner/dirichlet_partitioner.py | 7 +++++++ datasets/flwr_datasets/partitioner/iid_partitioner.py | 5 +++++ .../partitioner/inner_dirichlet_partitioner.py | 11 +++++++++++ .../partitioner/natural_id_partitioner.py | 7 +++++++ datasets/flwr_datasets/partitioner/partitioner.py | 5 +++++ .../flwr_datasets/partitioner/shard_partitioner.py | 9 +++++++++ .../flwr_datasets/partitioner/size_partitioner.py | 6 ++++++ 7 files changed, 50 insertions(+) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index 5f1df71991bb..5271aad74a1e 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -132,6 +132,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + def _initialize_alpha( self, alpha: Union[int, float, List[float], NDArrayFloat] ) -> NDArrayFloat: diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner.py b/datasets/flwr_datasets/partitioner/iid_partitioner.py index c72b34f081f2..faa1dfa10615 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner.py @@ -50,3 +50,8 @@ def load_partition(self, node_id: int) -> datasets.Dataset: return self.dataset.shard( num_shards=self._num_partitions, index=node_id, contiguous=True ) + + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + return self._num_partitions diff --git a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py index c25a9b059d18..bf07ab3591f5 100644 --- a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py @@ -119,6 +119,17 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._check_partition_sizes_correctness_if_needed() + self._check_the_sum_of_partition_sizes() + self._determine_num_unique_classes_if_needed() + self._alpha = self._initialize_alpha_if_needed(self._initial_alpha) + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + def _initialize_alpha_if_needed( self, alpha: Union[int, float, List[float], NDArrayFloat] ) -> NDArrayFloat: diff --git a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py index b8f28696f3b7..947501965cc6 100644 --- a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py +++ b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py @@ -65,6 +65,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset: lambda row: row[self._partition_by] == self._node_id_to_natural_id[node_id] ) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + if len(self._node_id_to_natural_id) == 0: + self._create_int_node_id_to_natural_id() + return len(self._node_id_to_natural_id) + @property def node_id_to_natural_id(self) -> Dict[int, str]: """Node id to corresponding natural id present. diff --git a/datasets/flwr_datasets/partitioner/partitioner.py b/datasets/flwr_datasets/partitioner/partitioner.py index 92405152efc6..73eb6f4a17b3 100644 --- a/datasets/flwr_datasets/partitioner/partitioner.py +++ b/datasets/flwr_datasets/partitioner/partitioner.py @@ -79,3 +79,8 @@ def is_dataset_assigned(self) -> bool: True if a dataset is assigned, otherwise False. """ return self._dataset is not None + + @property + @abstractmethod + def num_partitions(self) -> int: + """Total number of partitions.""" diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py index 7c86570fe487..05444f537c8c 100644 --- a/datasets/flwr_datasets/partitioner/shard_partitioner.py +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -179,6 +179,15 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._check_possibility_of_partitions_creation() + self._sort_dataset_if_needed() + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 """Assign sample indices to each node id. diff --git a/datasets/flwr_datasets/partitioner/size_partitioner.py b/datasets/flwr_datasets/partitioner/size_partitioner.py index 35ca750949ee..29fc2e5b1add 100644 --- a/datasets/flwr_datasets/partitioner/size_partitioner.py +++ b/datasets/flwr_datasets/partitioner/size_partitioner.py @@ -84,6 +84,12 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + @property def node_id_to_size(self) -> Dict[int, int]: """Node id to the number of samples."""