Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Jul 12, 2024
1 parent 350d9b4 commit f7c705f
Showing 1 changed file with 41 additions and 32 deletions.
73 changes: 41 additions & 32 deletions datasets/flwr_datasets/partitioner/distribution_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================

""" Distribution partitioner."""
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

import numpy as np
from flwr_datasets.partitioner.partitioner import Partitioner
Expand Down Expand Up @@ -68,13 +68,17 @@ class label distribution. `float` values are rounded to the nearest `int`.
>>> # Generate a vector from a log-normal probability distribution
>>> rng = np.random.default_rng(2024)
>>> mu, sigma = 0., 2.
>>> distribution_proba = rng.lognormal(mu, sigma, (num_clients*num_unique_labels_per_client,))
>>> distribution_proba = rng.lognormal(
>>> mu,
>>> sigma,
>>> (num_clients*num_unique_labels_per_client),
>>> )
>>>
>>> partitioner = DistributionPartitioner(
>>> distribution_array=distribution_proba,
>>> num_partitions=num_clients,
>>> num_unique_labels_per_partition=num_labels_per_client,
>>> partition_by="label", # Assumes that the dataset has a target column `label`
>>> partition_by="label", # Dataset has a target column `label`
>>> preassigned_num_samples_per_label=5,
>>> )
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
Expand Down Expand Up @@ -155,40 +159,45 @@ def _determine_partition_id_to_indices_if_needed(
unique_label_to_indices[unique_label]
)

label_distribution = np.fromiter(label_distribution_dict.values(), dtype=float)
if self._rescale:
# Compute the normalized distribution for each class label
self._distribution_array = self._distribution_array / np.sum(
self._distribution_array, axis=-1, keepdims=True
)

# Compute the normalized distribution for each class label
self._distribution_array = self._distribution_array / np.sum(
self._distribution_array, axis=-1, keepdims=True
)
# Compute the total preassigned number of samples per label for all labels
# and partitions. This sum will be subtracted
# from the label distribution from the original dataset, and added back later.
# It ensures that (1) each partition will have at least
# `self._preassigned_num_samples_per_label` and (2) there is sufficient
# indices to sample from the dataset.
total_preassigned_samples = int(
self._preassigned_num_samples_per_label
* self._num_unique_labels_per_partition
* self._num_partitions
/ self._num_unique_labels
)

# Compute the total preassigned number of samples per label for all labels
# and partitions. This sum will be subtracted
# from the label distribution from the original dataset, and added back later.
# It ensures that (1) each partition will have at least
# `self._preassigned_num_samples_per_label` and (2) there is sufficient
# indices to sample from the dataset.
total_preassigned_samples = int(
self._preassigned_num_samples_per_label
* self._num_unique_labels_per_partition
* self._num_partitions
/ self._num_unique_labels
)
label_distribution = np.fromiter(
label_distribution_dict.values(),
dtype=float,
)

# Subtract the preassigned total amount from the label distribution,
# we'll add these back later.
label_distribution -= total_preassigned_samples
# Subtract the preassigned total amount from the label distribution,
# we'll add these back later.
label_distribution -= total_preassigned_samples

# Rescale normalized distribution with the actual label distribution.
# Each row represents the number of samples to be taken for that class label
# and the sum of each row equals the total of each class label.
# TODO: Skip this step if rescale = False
label_sampling_matrix = np.floor(
(self._distribution_array * label_distribution[:, np.newaxis])
).astype(int)
# Rescale normalized distribution with the actual label distribution.
# Each row represents the number of samples to be taken for that class label
# and the sum of each row equals the total of each class label.
label_sampling_matrix = np.floor(
(self._distribution_array * label_distribution[:, np.newaxis])
).astype(int)

# Add back the preassigned total amount
label_sampling_matrix += self._preassigned_num_samples_per_label
# Add back the preassigned total amount
label_sampling_matrix += self._preassigned_num_samples_per_label
else:
label_sampling_matrix = self._distribution_array.astype(int)

# Create the label sampling dictionary
label_samples = {
Expand Down

0 comments on commit f7c705f

Please sign in to comment.