Add comments, examples, more functions
KarhouTam committed Jun 20, 2024
1 parent cd7bcbc commit f865947
Showing 1 changed file with 124 additions and 46 deletions: datasets/flwr_datasets/partitioner/semantic_partitioner.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Semantic partitioner class that works with Hugging Face Datasets."""
# NOTE: The SemanticPartitioner can only work with image datasets.


import warnings
from typing import Any, Callable, Dict, List, Optional, Union
@@ -36,36 +36,58 @@
class SemanticPartitioner(Partitioner):
"""Partitioner based on data semantic information.
Implementation based on "Bayesian Nonparametric Federated Learning of Neural Networks"
(https://arxiv.org/abs/1905.12022).
The algorithm sequentially divides the data for each label. The fractions of the
data with each label are drawn from a Dirichlet distribution and adjusted for
balancing before the data is assigned. If the `min_partition_size` is not satisfied,
the algorithm is run again (the fractions will change since it is a random process,
even though the alpha stays the same).
The notion of balancing is introduced explicitly here (not mentioned in the paper but
implemented in its code). It is a mechanism that excludes a partition from being
assigned new samples once the current number of samples on that partition
exceeds the average number that the partition would get under an even data
distribution. It is controlled by the `self_balancing` parameter.
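A minimal, illustrative sketch of the sampling-and-balancing idea described above
(plain NumPy, not the partitioner's exact code; the sizes are made up):
>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> alpha, num_partitions = 0.5, 5
>>> # Fractions of one label's samples across the partitions
>>> fractions = rng.dirichlet([alpha] * num_partitions)
>>> # Balancing: exclude partitions that already exceed the average size
>>> avg_size, current_sizes = 100, np.array([40, 160, 90, 20, 70])
>>> fractions = np.where(current_sizes > avg_size, 0.0, fractions)
>>> fractions = fractions / fractions.sum()  # renormalize before assigning samples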
NOTE: The SemanticPartitioner can ONLY work with image datasets.
This implementation is modified from the original implementation:
https://github.com/google-research/federated/tree/master/generalization,
which used tensorflow-federated.
Reference: https://arxiv.org/abs/2110.14216 (accepted at ICLR 2022).
The following description is taken from Section 4.1 of the paper.
The semantic partitioner's goal is to reverse-engineer the federated dataset-generating
process so that each client possesses semantically similar data. For example, for
the EMNIST dataset, we expect every client (writer) to (i) write in a consistent style
for each digit (intra-client intra-label similarity) and (ii) use a consistent writing
style across all digits (intra-client inter-label similarity). A simple approach might
be to cluster similar examples together and sample client data from clusters. However,
if one directly clusters the entire dataset, the resulting clusters may end up largely
correlated to labels. To disentangle the effect of label heterogeneity and semantic
heterogeneity, we propose the following algorithm to enforce intra-client intra-label
similarity and intra-client inter-label similarity in two separate stages.
• Stage 1: For each label, we embed examples using a pretrained neural network
(extracting semantic features), and fit a Gaussian Mixture Model to cluster pretrained
embeddings into groups. Note that this results in multiple groups per label.
This stage enforces intra-client intra-label consistency.
• Stage 2: To package the clusters from different labels into clients, we aim to compute
an optimal multi-partite matching with cost-matrix defined by KL-divergence between
the Gaussian clusters. To reduce complexity, we heuristically solve the optimal multi-partite
matching by progressively solving the optimal bipartite matching at each time for
randomly-chosen label pairs.
This stage enforces intra-client inter-label consistency. (A minimal sketch of both
stages follows.)
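A minimal sketch of the two stages (illustrative only: it uses scikit-learn and SciPy
directly and a simple distance between cluster means instead of the full KL-divergence
cost, so it is not the partitioner's exact code):
>>> import numpy as np
>>> from scipy.optimize import linear_sum_assignment
>>> from sklearn.mixture import GaussianMixture
>>>
>>> rng = np.random.default_rng(0)
>>> emb_a = rng.normal(size=(200, 16))  # embeddings of the examples with label "a"
>>> emb_b = rng.normal(size=(200, 16))  # embeddings of the examples with label "b"
>>> num_clients = 5
>>> # Stage 1: one Gaussian cluster per client within each label
>>> gmm_a = GaussianMixture(n_components=num_clients, init_params="kmeans").fit(emb_a)
>>> gmm_b = GaussianMixture(n_components=num_clients, init_params="kmeans").fit(emb_b)
>>> # Stage 2: bipartite matching between the clusters of the two labels
>>> cost = np.linalg.norm(
>>>     gmm_a.means_[:, None, :] - gmm_b.means_[None, :, :], axis=-1
>>> )
>>> row_ind, col_ind = linear_sum_assignment(cost)
>>> # Cluster col_ind[i] of label "b" joins the same client as cluster i of label "a"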
Parameters
----------
num_partitions : int
The total number of partitions that the data will be divided into.
partition_by : str
Column name of the labels (targets) based on which the partitioning works.
alpha : Union[int, float, List[float], NDArrayFloat]
Concentration parameter of the Dirichlet distribution.
min_partition_size : int
The minimum number of samples that each partition will have (the sampling
process is repeated if any partition is too small).
self_balancing : bool
Whether to assign further samples to a partition after its number of samples
has exceeded the average number of samples per partition (True in the original
paper's code, although not mentioned in the paper itself).
efficient_net_type: int
The type of pretrained EfficientNet model.
Options: [0, 1, 2, 3, 4, 5, 6, 7], corresponding to EfficientNet B0-B7 models.
pca_components: int
The number of PCA components for dimensionality reduction.
gmm_max_iter: int
The maximum number of iterations for the GMM algorithm.
gmm_init_params: str
The initialization method for the GMM algorithm.
Options: ["random", "kmeans", "k-means++"]
use_cuda: bool
Whether to use CUDA for computation acceleration.
shuffle: bool
Whether to randomize the order of samples. Shuffling is applied after the
samples are assigned to partitions.
@@ -75,27 +97,32 @@ class SemanticPartitioner(Partitioner):
Examples
--------
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import SemanticPartitioner
>>>
>>> partitioner = SemanticPartitioner(
>>>     num_partitions=10,
>>>     partition_by="label",
>>>     pca_components=128,
>>>     gmm_max_iter=100,
>>>     gmm_init_params="kmeans",
>>>     use_cuda=True,
>>>     shuffle=True,
>>> )
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)
>>> print(partition[0])  # Print the first example
{'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7FCF49741B10>, 'label': 3}
>>> partition_sizes = [
>>>     len(fds.load_partition(partition_id)) for partition_id in range(5)
>>> ]
>>> print(sorted(partition_sizes))
[3163, 5278, 5496, 6320, 9522]
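The semantic features used in Stage 1 come from a pretrained EfficientNet. A
standalone sketch of such feature extraction with torchvision (illustrative only;
the partitioner performs this step internally):
>>> import torch
>>> from torchvision import models
>>>
>>> weights = models.EfficientNet_B0_Weights.DEFAULT
>>> backbone = models.efficientnet_b0(weights=weights)
>>> backbone.classifier = torch.nn.Identity()  # keep pooled features, drop the head
>>> backbone = backbone.eval()
>>> with torch.no_grad():
>>>     embeddings = backbone(torch.randn(8, 3, 224, 224))  # shape: (8, 1280)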
"""

def __init__( # pylint: disable=R0913
self,
num_partitions: int,
partition_by: str,
efficient_net_type: int = 3,
pca_components: int = 128,
gmm_max_iter: int = 100,
gmm_init_params: str = "kmeans",
@@ -126,6 +153,9 @@ def __init__( # pylint: disable=R0913
self._shuffle = shuffle
self._seed = seed
self._rng_numpy = np.random.default_rng(seed=self._seed)
# Default image column name; some datasets use a different name (e.g. cifar10 uses "img"),
# so this variable may be updated in self._check_data_validation_if_needed()
self._data_column_name = "image"
self._check_variable_validation()
# Utility attributes
# The attributes below are determined during the first call to load_partition
@@ -155,6 +185,7 @@ def load_partition(self, partition_id: int) -> datasets.Dataset:
# The partitioning is done lazily - only when the first partition is
# requested. Only the first call creates the indices assignments for all the
# partition indices.
self._check_data_validation_if_needed()
self._check_num_partitions_correctness_if_needed()
self._check_pca_components_validation_if_needed()
self._determine_partition_id_to_indices_if_needed()
@@ -306,15 +337,21 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
self._partition_id_to_indices_determined = True

def _preprocess_dataset_images(self):
images = np.array(self.dataset[self._data_column_name], dtype=np.float32)
if len(images.shape) == 3:  # grayscale images: (N, H, W) -> (N, 1, H, W)
    images = np.reshape(
        images, (images.shape[0], 1, images.shape[1], images.shape[2])
    )
elif len(images.shape) == 4:  # images with a channel dimension
    x, y, z = images.shape[1:]
    if z < x and z < y:  # channels-last [H, W, C] -> [C, H, W]
        images = np.transpose(images, (0, 3, 1, 2))
    elif x < y and x < z:  # already channels-first [C, H, W]
        pass
    else:
        raise ValueError(
            f"The image shape is not supported. Now: {images.shape}"
        )
return images
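# Illustrative shape handling of _preprocess_dataset_images (assumed typical inputs,
# shown here only as documentation):
#   (N, 28, 28)     -> (N, 1, 28, 28)   grayscale images, channel axis added
#   (N, 32, 32, 3)  -> (N, 3, 32, 32)   channels-last transposed to channels-first
#   (N, 3, 32, 32)  -> unchanged        already channels-first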

def _check_num_partitions_correctness_if_needed(self) -> None:
@@ -340,6 +377,35 @@ def _check_pca_components_validation_if_needed(self) -> None:
f"Now: {self._pca_components}."
)

def _check_data_validation_if_needed(self):
"""Test whether dataset is image dataset"""
if not self._partition_id_to_indices_determined:
features_dict = self.dataset.features.to_dict()
self._data_column_name = list(features_dict.keys())[0]
try:
data = np.array(
self.dataset[self._data_column_name][0], dtype=np.float32
)
except:
raise TypeError(
"The dataset needs to be image dataset. "
f"Now: {type(self.dataset[self._data_column_name][0])}."
)

if not (2 <= len(data.shape) <= 3):
raise ValueError(
"The image shape is not supported. "
"The image shape should among {[H, W], [C, H, W], [H, W, C]}. "
f"Now: {data.shape}. "
)
elif len(data.shape) == 3:
x, y, z = data.shape
if not ((x < y and x < z) or (z < x and z < y)) :
raise ValueError(
"The 3D image shape should be [C, H, W] or [H, W, C]. "
f"Now: {data.shape}. "
)

def _check_variable_validation(self):
"""Test class variables validation."""
if not self._num_partitions > 0:
@@ -378,23 +444,35 @@ def _pairwise_kl_div(


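# The Stage-2 cost matrix is built from KL divergences between the Gaussian clusters.
# A minimal NumPy sketch of the closed-form KL divergence between two multivariate
# Gaussians N(mu0, cov0) and N(mu1, cov1) is shown below. It is an illustrative helper
# only, not the `_pairwise_kl_div` implementation collapsed in this diff.
def _kl_div_gaussians_sketch(mu0, cov0, mu1, cov1):
    # KL(N0 || N1) = 0.5 * (tr(S1^-1 S0) + (m1 - m0)^T S1^-1 (m1 - m0) - k
    #                       + ln det S1 - ln det S0)
    k = mu0.shape[0]
    cov1_inv = np.linalg.inv(cov1)
    diff = mu1 - mu0
    _, logdet0 = np.linalg.slogdet(cov0)
    _, logdet1 = np.linalg.slogdet(cov1)
    return 0.5 * (
        np.trace(cov1_inv @ cov0) + diff @ cov1_inv @ diff - k + logdet1 - logdet0
    )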
if __name__ == "__main__":
# ===================== Test with custom Dataset =====================
from datasets import Dataset

# data = {
# "labels": [i % 3 for i in range(50)],
# "features": [np.random.randn(1, 28, 28) for _ in range(50)],
# }
# dataset = Dataset.from_dict(data)
data = {
"image": [np.random.randn(28, 28) for _ in range(50)],
"label": [i % 3 for i in range(50)],
}
dataset = Dataset.from_dict(data)
partitioner = SemanticPartitioner(
num_partitions=5, partition_by="label", pca_components=30
)
partitioner.dataset = dataset
partition = partitioner.load_partition(0)
partition_sizes = [
    len(partitioner.load_partition(partition_id)) for partition_id in range(5)
]
print(sorted(partition_sizes))
# ====================================================================

# ===================== Test with FederatedDataset =====================
# from flwr_datasets import FederatedDataset
# partitioner = SemanticPartitioner(
# num_partitions=5, partition_by="label", pca_components=128
# )
# fds = FederatedDataset(dataset="cifar10", partitioners={"train": partitioner})
# partition = fds.load_partition(0)
# print(partition[0]) # Print the first example
# partition_sizes = [
# len(fds.load_partition(partition_id)) for partition_id in range(5)
# ]
# print(sorted(partition_sizes))
# ======================================================================
