Skip to content

Commit

Permalink
Fds add num_partitions property to partitioners (#3095)
Browse files Browse the repository at this point in the history
* Add num_partitions property

* Trigger the partitioning in the num_partitions

---------

Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
adam-narozniak and danieljanes authored Mar 12, 2024
1 parent d6f274b commit 1057001
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 0 deletions.
7 changes: 7 additions & 0 deletions datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
    """Total number of partitions.

    Reading this property is not a plain attribute lookup: it first runs
    the partition-count check and then the node-id-to-indices assignment,
    i.e. it triggers partitioning before the stored count is returned.
    """
    # Keep this order: the count is checked before any indices are assigned.
    # NOTE(review): both helpers are named "*_if_needed"; presumably they are
    # no-ops on repeated access -- confirm in their definitions.
    self._check_num_partitions_correctness_if_needed()
    self._determine_node_id_to_indices_if_needed()
    return self._num_partitions

def _initialize_alpha(
self, alpha: Union[int, float, List[float], NDArrayFloat]
) -> NDArrayFloat:
Expand Down
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/partitioner/iid_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,8 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
return self.dataset.shard(
num_shards=self._num_partitions, index=node_id, contiguous=True
)

@property
def num_partitions(self) -> int:
    """Total number of partitions.

    Returns the stored ``_num_partitions`` value as-is; no checks or
    partitioning work are performed on access.
    """
    return self._num_partitions
11 changes: 11 additions & 0 deletions datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
    """Total number of partitions.

    Reading this property runs the partitioner's checks and setup helpers
    and then the node-id-to-indices assignment before returning the stored
    count. As a side effect ``self._alpha`` is (re)assigned from
    ``self._initial_alpha``.
    """
    # The call order below is deliberate and must be preserved.
    # NOTE(review): later helpers presumably rely on state prepared by the
    # earlier ones -- confirm against the helper definitions.
    self._check_num_partitions_correctness_if_needed()
    self._check_partition_sizes_correctness_if_needed()
    self._check_the_sum_of_partition_sizes()
    self._determine_num_unique_classes_if_needed()
    self._alpha = self._initialize_alpha_if_needed(self._initial_alpha)
    self._determine_node_id_to_indices_if_needed()
    return self._num_partitions

def _initialize_alpha_if_needed(
self, alpha: Union[int, float, List[float], NDArrayFloat]
) -> NDArrayFloat:
Expand Down
7 changes: 7 additions & 0 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
lambda row: row[self._partition_by] == self._node_id_to_natural_id[node_id]
)

@property
def num_partitions(self) -> int:
    """Total number of partitions.

    The count is the number of entries in ``_node_id_to_natural_id``. If
    that mapping is still empty, it is built first via
    ``_create_int_node_id_to_natural_id``.
    """
    # Build the mapping on first access; an empty mapping means "not built".
    if not self._node_id_to_natural_id:
        self._create_int_node_id_to_natural_id()
    return len(self._node_id_to_natural_id)

@property
def node_id_to_natural_id(self) -> Dict[int, str]:
"""Node id to corresponding natural id present.
Expand Down
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/partitioner/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,8 @@ def is_dataset_assigned(self) -> bool:
True if a dataset is assigned, otherwise False.
"""
return self._dataset is not None

@property
@abstractmethod
def num_partitions(self) -> int:
    """Total number of partitions.

    Abstract read-only property: it has no implementation here and every
    concrete subclass must provide one.
    """
9 changes: 9 additions & 0 deletions datasets/flwr_datasets/partitioner/shard_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
    """Total number of partitions.

    Reading this property runs the partition-count check, the
    partition-feasibility check, the dataset sorting step and the
    node-id-to-indices assignment, in that order, before returning the
    stored count.
    """
    # The call order below is deliberate and must be preserved.
    # NOTE(review): the "*_if_needed" helpers are presumably no-ops on
    # repeated access -- confirm in their definitions.
    self._check_num_partitions_correctness_if_needed()
    self._check_possibility_of_partitions_creation()
    self._sort_dataset_if_needed()
    self._determine_node_id_to_indices_if_needed()
    return self._num_partitions

def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914
"""Assign sample indices to each node id.
Expand Down
6 changes: 6 additions & 0 deletions datasets/flwr_datasets/partitioner/size_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
    """Total number of partitions.

    Reading this property runs the node-id-to-indices assignment first,
    i.e. it triggers partitioning, and then returns the stored count.
    """
    # NOTE(review): presumably a no-op once the indices have been
    # determined -- confirm in the helper's definition.
    self._determine_node_id_to_indices_if_needed()
    return self._num_partitions

@property
def node_id_to_size(self) -> Dict[int, int]:
"""Node id to the number of samples."""
Expand Down

0 comments on commit 1057001

Please sign in to comment.