
Commit

Reformat
KarhouTam committed Jun 20, 2024
1 parent f178eff commit e8cb7ab
Showing 2 changed files with 24 additions and 26 deletions.
2 changes: 1 addition & 1 deletion datasets/flwr_datasets/partitioner/__init__.py
@@ -35,7 +35,7 @@
     "LinearPartitioner",
     "NaturalIdPartitioner",
     "Partitioner",
-    "SemanticPartitioner"
+    "SemanticPartitioner",
     "ShardPartitioner",
     "SizePartitioner",
     "SquarePartitioner",
48 changes: 23 additions & 25 deletions datasets/flwr_datasets/partitioner/semantic_partitioner.py
@@ -44,31 +44,31 @@ class SemanticPartitioner(Partitioner):
References:
What Do We Mean by Generalization in Federated Learning? (accepted by ICLR 2022)
https://arxiv.org/abs/2110.14216
(Cited from section 4.1 in the paper)
Semantic partitioner's goal is to reverse-engineer the federated dataset-generating
process so that each client possesses semantically similar data. For example, for
the EMNIST dataset, we expect every client (writer) to (i) write in a consistent style
for each digit (intra-client intra-label similarity) and (ii) use a consistent writing
style across all digits (intra-client inter-label similarity). A simple approach might
be to cluster similar examples together and sample client data from clusters. However,
if one directly clusters the entire dataset, the resulting clusters may end up largely
correlated to labels. To disentangle the effect of label heterogeneity and semantic
heterogeneity, we propose the following algorithm to enforce intra-client intra-label
similarity and intra-client inter-label similarity in two separate stages.
• Stage 1: For each label, we embed examples using a pretrained neural network
(extracting semantic features), and fit a Gaussian Mixture Model to cluster pretrained
embeddings into groups. Note that this results in multiple groups per label.
This stage enforces intra-client intra-label consistency.
• Stage 2: To package the clusters from different labels into clients, we aim to compute
an optimal multi-partite matching with a cost matrix defined by the KL-divergence between
the Gaussian clusters. To reduce complexity, we heuristically solve the optimal multi-partite
matching by progressively solving the optimal bipartite matching, one randomly chosen
label pair at a time.
This stage enforces intra-client inter-label consistency.
Parameters
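
To make Stage 1 above concrete, here is a minimal sketch under stated assumptions. It is
not the code from this commit: the helper name cluster_embeddings_per_label, the five
components per label, and the diagonal covariance type are illustrative choices, and the
pretrained embeddings are taken as given.

# Stage 1 (illustrative sketch): per-label GMM clustering of pretrained embeddings.
# Assumptions not from the commit: helper name, component count, diagonal covariances.
import numpy as np
from sklearn.mixture import GaussianMixture


def cluster_embeddings_per_label(
    embeddings: np.ndarray,  # (n_samples, embedding_dim), from a pretrained network
    labels: np.ndarray,  # (n_samples,)
    clusters_per_label: int = 5,
    seed: int = 42,
) -> dict:
    """Fit one GMM per label; each mixture component is one semantic group."""
    gmms = {}
    for label in np.unique(labels):
        gmm = GaussianMixture(
            n_components=clusters_per_label,
            covariance_type="diag",  # keeps the Stage 2 KL-divergence closed-form and cheap
            random_state=seed,
        )
        gmm.fit(embeddings[labels == label])
        gmms[int(label)] = gmm
    return gmms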
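
Likewise, a sketch of one Stage 2 step under the same assumptions: pair the Gaussian
components of two labels by optimal bipartite matching, with a cost matrix given by the
closed-form KL-divergence between diagonal Gaussians. The helper names are hypothetical;
linear_sum_assignment is SciPy's optimal-assignment solver.

# Stage 2 (illustrative sketch): bipartite matching of two labels' clusters
# with a KL-divergence cost matrix; helper names are assumptions.
import numpy as np
from scipy.optimize import linear_sum_assignment


def kl_diag_gaussians(mu0, var0, mu1, var1):
    """Closed-form KL( N(mu0, diag(var0)) || N(mu1, diag(var1)) )."""
    return 0.5 * np.sum(
        np.log(var1 / var0) + (var0 + (mu0 - mu1) ** 2) / var1 - 1.0
    )


def match_label_pair(gmm_a, gmm_b):
    """Pair each component of gmm_a with one of gmm_b, minimizing total KL."""
    k = gmm_a.n_components
    cost = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            cost[i, j] = kl_diag_gaussians(
                gmm_a.means_[i], gmm_a.covariances_[i],  # covariances_ is (k, dim) for "diag"
                gmm_b.means_[j], gmm_b.covariances_[j],
            )
    rows, cols = linear_sum_assignment(cost)  # optimal bipartite matching
    return dict(zip(rows, cols))  # cluster i of label A -> cluster cols[i] of label B

Repeating this over randomly chosen label pairs approximates the multi-partite matching,
and chaining the matched clusters across labels yields one cluster per label per client.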
@@ -338,9 +338,7 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
         self._partition_id_to_indices_determined = True

     def _preprocess_dataset_images(self):
-        images = np.array(
-            self.dataset[self._data_column_name], dtype=np.float32
-        )
+        images = np.array(self.dataset[self._data_column_name], dtype=np.float32)
         if len(images.shape) == 3:  # 1D
             images = np.reshape(
                 images, (images.shape[0], 1, images.shape[1], images.shape[2])
@@ -401,7 +399,7 @@ def _check_data_validation_if_needed(self):
                 )
             elif len(data.shape) == 3:
                 x, y, z = data.shape
-                if not ((x < y and x < z) or (z < x and z < y)) :
+                if not ((x < y and x < z) or (z < x and z < y)):
                     raise ValueError(
                         "The 3D image shape should be [C, H, W] or [H, W, C]. "
                         f"Now: {data.shape}. "
