diff --git a/datasets/flwr_datasets/partitioner/distribution_partitioner.py b/datasets/flwr_datasets/partitioner/distribution_partitioner.py index e3c251ec41d5..701f6719682b 100644 --- a/datasets/flwr_datasets/partitioner/distribution_partitioner.py +++ b/datasets/flwr_datasets/partitioner/distribution_partitioner.py @@ -41,11 +41,13 @@ class DistributionPartitioner(Partitioner): # pylint: disable=R0902 ( `num_unique_labels`, ---------------------------------------------------- ), `num_unique_labels` the label_id at the i'th row is assigned to the partition_id based on the formula: - partition_id = alpha + beta + partition_id = <alpha + beta> where, + <.> denotes the reindexed sequence of partition_ids in monotone increasing + order for all j's alpha* = (i - num_unique_labels_per_partition + 1) \ - + (j % num_unique_labels_per_partition) - alpha = alpha* + (alpha* > 0 ? 0 : num_unique_labels) + + (j % num_unique_labels_per_partition), + alpha = alpha* + (alpha* >= 0 ? 0 : num_unique_labels), beta = num_unique_labels * (j // num_unique_labels_per_partition) and j in {0, 1, 2, ..., `num_columns`}. Each list representing the partition_ids for the i'th row is sorted in ascending order. So, for a dataset with 10 unique labels