Add a partition_sizes check
adam-narozniak committed Feb 28, 2024
1 parent 31ed29d commit 76d2fa9
Showing 2 changed files with 23 additions and 2 deletions.
14 changes: 12 additions & 2 deletions datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
@@ -102,6 +102,7 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
         # requested. Only the first call creates the indices assignments for all the
         # partition indices.
         self._check_num_partitions_correctness_if_needed()
+        self._check_partition_sizes_correctness_if_needed()
         self._check_the_sum_of_partition_sizes()
         self._determine_num_unique_classes_if_needed()
         self._alpha = self._initialize_alpha_if_needed(self._initial_alpha)
@@ -234,8 +235,17 @@ def _check_num_partitions_correctness_if_needed(self) -> None:
         if not self._node_id_to_indices_determined:
             if self._num_partitions > self.dataset.num_rows:
                 raise ValueError(
-                    "The number of partitions needs to be smaller than the number of "
-                    "samples in the dataset."
+                    "The number of partitions needs to be smaller than or equal to "
+                    "the number of samples in the dataset."
                 )
 
+    def _check_partition_sizes_correctness_if_needed(self) -> None:
+        """Test partition_sizes when the dataset is given (in load_partition)."""
+        if not self._node_id_to_indices_determined:
+            if sum(self._partition_sizes) > self.dataset.num_rows:
+                raise ValueError(
+                    "The sum of the `partition_sizes` needs to be smaller than or "
+                    "equal to the number of samples in the dataset."
+                )
+
     def _check_num_partitions_greater_than_zero(self) -> None:
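For context, a minimal sketch, not part of this commit, of how the new check surfaces through the public API. It assumes the constructor arguments exercised in the test below (partition_sizes, partition_by, alpha), the dataset setter inherited from the Partitioner base class, and a small dummy Hugging Face Dataset built on the fly.

from datasets import Dataset
from flwr_datasets.partitioner import InnerDirichletPartitioner

# A 100-row dummy dataset; partition_sizes summing to 120 should now fail fast.
dataset = Dataset.from_dict(
    {"features": list(range(100)), "labels": [i % 3 for i in range(100)]}
)
partitioner = InnerDirichletPartitioner(
    partition_sizes=[40, 40, 40],  # sums to 120 > 100 rows
    partition_by="labels",
    alpha=1.0,
)
partitioner.dataset = dataset  # assumed setter from the Partitioner base class

try:
    partitioner.load_partition(0)
except ValueError as err:
    print(err)  # the new `partition_sizes` sum error is raised before any sampling

With this commit, the mismatch is rejected up front in load_partition, alongside the existing num_partitions check shown in the first hunk.
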
11 changes: 11 additions & 0 deletions datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner_test.py
@@ -90,6 +90,17 @@ def test_incorrect_shape_of_alpha(self) -> None:
         with self.assertRaises(ValueError):
             _ = partitioner.load_partition(0)
 
+    def test_too_big_sum_of_partition_sizes(self) -> None:
+        """Test sum of partition_sizes greater than the size of the dataset."""
+        num_rows = 113
+        partition_by = "labels"
+        alpha = 1.0
+        partition_sizes = [60, 60, 30, 43]
+
+        _, partitioner = _dummy_setup(num_rows, partition_by, partition_sizes, alpha)
+        with self.assertRaises(ValueError):
+            _ = partitioner.load_partition(0)
+
 
 if __name__ == "__main__":
     unittest.main()
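
The `_dummy_setup` helper used in the new test is defined earlier in the test file and is not shown in this diff. A hypothetical stand-in along these lines (the dummy column names and label values are assumptions) illustrates what the test relies on: a small labelled Dataset attached to a freshly configured InnerDirichletPartitioner.

from typing import List, Tuple

from datasets import Dataset
from flwr_datasets.partitioner import InnerDirichletPartitioner


def _dummy_setup_sketch(
    num_rows: int, partition_by: str, partition_sizes: List[int], alpha: float
) -> Tuple[Dataset, InnerDirichletPartitioner]:
    # Build a small labelled dataset and attach it to a freshly configured partitioner.
    dataset = Dataset.from_dict(
        {
            "features": list(range(num_rows)),
            partition_by: [i % 3 for i in range(num_rows)],
        }
    )
    partitioner = InnerDirichletPartitioner(
        partition_sizes=partition_sizes, partition_by=partition_by, alpha=alpha
    )
    partitioner.dataset = dataset
    return dataset, partitioner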
