From db90c9dd809dd264b64019c5d1d7813cf6230823 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Fri, 23 Aug 2024 10:48:59 +0200 Subject: [PATCH] Move num_groups computation out of if branches --- .../partitioner/grouped_natural_id_partitioner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py b/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py index b9bd610fed73..0a647e47f41a 100644 --- a/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py +++ b/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py @@ -91,12 +91,13 @@ def _create_int_partition_id_to_natural_id(self) -> None: unique_natural_ids = sorted(unique_natural_ids) num_unique_natural_ids = len(unique_natural_ids) remainder = num_unique_natural_ids % self._group_size + num_groups = num_unique_natural_ids // self._group_size + # Note that the number of groups might be different than this number + # due to certain modes; it's a base value. if self._mode == "allow-bigger": - num_groups = num_unique_natural_ids // self._group_size groups_of_natural_ids = np.array_split(unique_natural_ids, num_groups) elif self._mode == "drop-reminder": - num_groups = num_unique_natural_ids // self._group_size # Narrow down the unique_natural_ids to not have a bigger group # which is the behavior of the np.array_split unique_natural_ids = unique_natural_ids[ @@ -104,7 +105,6 @@ def _create_int_partition_id_to_natural_id(self) -> None: ] groups_of_natural_ids = np.array_split(unique_natural_ids, num_groups) elif self._mode == "allow-smaller": - num_groups = num_unique_natural_ids // self._group_size if remainder > 0: last_group_ids = unique_natural_ids[-remainder:] unique_natural_ids = unique_natural_ids[ @@ -122,7 +122,6 @@ def _create_int_partition_id_to_natural_id(self) -> None: f"enables strict mode or relax the mode parameter. 
Refer to the " f"documentation of the mode parameter for the available modes." ) - num_groups = num_unique_natural_ids // self._group_size groups_of_natural_ids = np.array_split(unique_natural_ids, num_groups) else: raise ValueError(