diff --git a/datasets/flwr_datasets/partitioner/distribution_partitioner.py b/datasets/flwr_datasets/partitioner/distribution_partitioner.py index f3721dacb4fe..06c437c9d533 100644 --- a/datasets/flwr_datasets/partitioner/distribution_partitioner.py +++ b/datasets/flwr_datasets/partitioner/distribution_partitioner.py @@ -14,7 +14,7 @@ # ============================================================================== """ Distribution partitioner.""" -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import numpy as np from flwr_datasets.partitioner.partitioner import Partitioner @@ -68,13 +68,17 @@ class label distribution. `float` values are rounded to the nearest `int`. >>> # Generate a vector from a log-normal probability distribution >>> rng = np.random.default_rng(2024) >>> mu, sigma = 0., 2. - >>> distribution_proba = rng.lognormal(mu, sigma, (num_clients*num_unique_labels_per_client,)) + >>> distribution_proba = rng.lognormal( + >>> mu, + >>> sigma, + >>> (num_clients*num_unique_labels_per_client), + >>> ) >>> >>> partitioner = DistributionPartitioner( >>> distribution_array=distribution_proba, >>> num_partitions=num_clients, >>> num_unique_labels_per_partition=num_labels_per_client, - >>> partition_by="label", # Assumes that the dataset has a target column `label` + >>> partition_by="label", # Dataset has a target column `label` >>> preassigned_num_samples_per_label=5, >>> ) >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) @@ -155,40 +159,45 @@ def _determine_partition_id_to_indices_if_needed( unique_label_to_indices[unique_label] ) - label_distribution = np.fromiter(label_distribution_dict.values(), dtype=float) + if self._rescale: + # Compute the normalized distribution for each class label + self._distribution_array = self._distribution_array / np.sum( + self._distribution_array, axis=-1, keepdims=True + ) - # Compute the normalized distribution for each class label - self._distribution_array = self._distribution_array / np.sum( - self._distribution_array, axis=-1, keepdims=True - ) + # Compute the total preassigned number of samples per label for all labels + # and partitions. This sum will be subtracted + # from the label distribution from the original dataset, and added back later. + # It ensures that (1) each partition will have at least + # `self._preassigned_num_samples_per_label` and (2) there is sufficient + # indices to sample from the dataset. + total_preassigned_samples = int( + self._preassigned_num_samples_per_label + * self._num_unique_labels_per_partition + * self._num_partitions + / self._num_unique_labels + ) - # Compute the total preassigned number of samples per label for all labels - # and partitions. This sum will be subtracted - # from the label distribution from the original dataset, and added back later. - # It ensures that (1) each partition will have at least - # `self._preassigned_num_samples_per_label` and (2) there is sufficient - # indices to sample from the dataset. - total_preassigned_samples = int( - self._preassigned_num_samples_per_label - * self._num_unique_labels_per_partition - * self._num_partitions - / self._num_unique_labels - ) + label_distribution = np.fromiter( + label_distribution_dict.values(), + dtype=float, + ) - # Subtract the preassigned total amount from the label distribution, - # we'll add these back later. - label_distribution -= total_preassigned_samples + # Subtract the preassigned total amount from the label distribution, + # we'll add these back later. + label_distribution -= total_preassigned_samples - # Rescale normalized distribution with the actual label distribution. - # Each row represents the number of samples to be taken for that class label - # and the sum of each row equals the total of each class label. - # TODO: Skip this step if rescale = False - label_sampling_matrix = np.floor( - (self._distribution_array * label_distribution[:, np.newaxis]) - ).astype(int) + # Rescale normalized distribution with the actual label distribution. + # Each row represents the number of samples to be taken for that class label + # and the sum of each row equals the total of each class label. + label_sampling_matrix = np.floor( + (self._distribution_array * label_distribution[:, np.newaxis]) + ).astype(int) - # Add back the preassigned total amount - label_sampling_matrix += self._preassigned_num_samples_per_label + # Add back the preassigned total amount + label_sampling_matrix += self._preassigned_num_samples_per_label + else: + label_sampling_matrix = self._distribution_array.astype(int) # Create the label sampling dictionary label_samples = {