diff --git a/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py b/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py index b9bd610fed73..0a647e47f41a 100644 --- a/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py +++ b/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py @@ -91,12 +91,13 @@ def _create_int_partition_id_to_natural_id(self) -> None: unique_natural_ids = sorted(unique_natural_ids) num_unique_natural_ids = len(unique_natural_ids) remainder = num_unique_natural_ids % self._group_size + num_groups = num_unique_natural_ids // self._group_size + # Note that the number of groups might be different than this number + # due to certain modes; it's a base value. if self._mode == "allow-bigger": - num_groups = num_unique_natural_ids // self._group_size groups_of_natural_ids = np.array_split(unique_natural_ids, num_groups) elif self._mode == "drop-reminder": - num_groups = num_unique_natural_ids // self._group_size # Narrow down the unique_natural_ids to not have a bigger group # which is the behavior of the np.array_split unique_natural_ids = unique_natural_ids[ @@ -104,7 +105,6 @@ def _create_int_partition_id_to_natural_id(self) -> None: ] groups_of_natural_ids = np.array_split(unique_natural_ids, num_groups) elif self._mode == "allow-smaller": - num_groups = num_unique_natural_ids // self._group_size if remainder > 0: last_group_ids = unique_natural_ids[-remainder:] unique_natural_ids = unique_natural_ids[ @@ -122,7 +122,6 @@ def _create_int_partition_id_to_natural_id(self) -> None: f"enables strict mode or relax the mode parameter. Refer to the " f"documentation of the mode parameter for the available modes." ) - num_groups = num_unique_natural_ids // self._group_size groups_of_natural_ids = np.array_split(unique_natural_ids, num_groups) else: raise ValueError(