Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(datasets) Add pathological partitioner #3623

Merged
merged 25 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7c95f68
Add class contrained partitioner
adam-narozniak Jun 17, 2024
de2c117
Update class constrained partitioner
adam-narozniak Jun 26, 2024
d5af1c6
Add to init
adam-narozniak Jun 26, 2024
079f7b2
Add tests for class_constrained_partitioner
adam-narozniak Jun 26, 2024
27a086c
Add class_assignment_mode argument
adam-narozniak Jul 9, 2024
6799cf1
Update tests
adam-narozniak Jul 9, 2024
6d695c0
Improve docs
adam-narozniak Jul 9, 2024
c3896de
Fix tests
adam-narozniak Jul 9, 2024
c31183a
Fix formatting
adam-narozniak Jul 9, 2024
38469bd
Rename the partitioner to pathological
adam-narozniak Jul 9, 2024
c801ce7
Rename the partitioner to pathological
adam-narozniak Jul 9, 2024
127155b
Fix docs to render correctly
adam-narozniak Jul 9, 2024
f9bb20f
Merge branch 'main' into fds-add-class-constrained
adam-narozniak Jul 9, 2024
cb1d338
Add information how the class_assignment_mode relates to the paper
adam-narozniak Jul 9, 2024
42555a1
Apply suggestions from code review
adam-narozniak Jul 10, 2024
9caa761
Update label types
adam-narozniak Jul 10, 2024
0ec1317
Add unique label sorting
adam-narozniak Jul 10, 2024
1a2698e
Add information about unique_labels sorting
adam-narozniak Jul 10, 2024
4210e3c
Update datasets/flwr_datasets/partitioner/pathological_partitioner_te…
jafermarq Jul 10, 2024
f5cef40
Remove prints from the tests
adam-narozniak Jul 11, 2024
28af927
Raname method name for clarity
adam-narozniak Jul 11, 2024
9903628
Add explanation to _count_partitions_having_each_unique_label
adam-narozniak Jul 11, 2024
2623601
Update datasets/flwr_datasets/partitioner/pathological_partitioner.py
jafermarq Jul 12, 2024
3b95128
Merge branch 'main' into fds-add-class-constrained
jafermarq Jul 12, 2024
c5730bc
Fix formatting
adam-narozniak Jul 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/partitioner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .linear_partitioner import LinearPartitioner
from .natural_id_partitioner import NaturalIdPartitioner
from .partitioner import Partitioner
from .pathological_partitioner import PathologicalPartitioner
from .shard_partitioner import ShardPartitioner
from .size_partitioner import SizePartitioner
from .square_partitioner import SquarePartitioner
Expand All @@ -34,6 +35,7 @@
"LinearPartitioner",
"NaturalIdPartitioner",
"Partitioner",
"PathologicalPartitioner",
"ShardPartitioner",
"SizePartitioner",
"SquarePartitioner",
Expand Down
293 changes: 293 additions & 0 deletions datasets/flwr_datasets/partitioner/pathological_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pathological partitioner class that works with Hugging Face Datasets."""


import warnings
from typing import Dict, List, Literal, Optional

import numpy as np

import datasets
from flwr_datasets.common.typing import NDArray
from flwr_datasets.partitioner.partitioner import Partitioner


# pylint: disable=too-many-arguments, too-many-instance-attributes
class PathologicalPartitioner(Partitioner):
    """Partition dataset such that each partition has a chosen number of classes.

    Implementation based on Federated Learning on Non-IID Data Silos: An Experimental
    Study https://arxiv.org/pdf/2102.02079.

    The algorithm first determines which classes will be assigned to which partitions.
    For each partition, `num_classes_per_partition` classes are sampled in a way chosen
    in `class_assignment_mode`. Given the information about the required classes for
    each partition, it is determined into how many parts the samples corresponding to
    this label should be divided. Such division is performed for each class.

    Parameters
    ----------
    num_partitions : int
        The total number of partitions that the data will be divided into.
    partition_by : str
        Column name of the labels (targets) based on which partitioning works.
    num_classes_per_partition: int
        The (exact) number of unique classes that each partition will have.
    class_assignment_mode: Literal["random", "deterministic", "first-deterministic"]
        The way how the classes are assigned to the partitions. The default is "random".
        The possible values are:

        - "random": Randomly assign classes to the partitions. For each partition choose
          the `num_classes_per_partition` classes without replacement.
        - "first-deterministic": Assign the first class for each partition in a
          deterministic way (the class at index partition_id % num_unique_classes).
          The rest of the classes are assigned randomly. In case the number of
          partitions is smaller than the number of unique classes, not all classes will
          be used in the first iteration, otherwise all the classes will be used (each
          of them will be present in at least one partition).
        - "deterministic": Assign all the classes to the partitions in a deterministic
          way. Classes are assigned based on the formula: partition_id has classes
          identified by the index: (partition_id + i) % num_unique_classes
          where i in {0, ..., num_classes_per_partition}. So, partition 0 will have
          classes 0, 1, 2, ..., `num_classes_per_partition`-1, partition 1 will have
          classes 1, 2, 3, ...,`num_classes_per_partition`, ....

        `class_assignment_mode="first-deterministic"` was used in the original paper,
        here we provide the option to use the other modes as well.
    shuffle: bool
        Whether to randomize the order of samples. Shuffling applied after the
        samples assignment to partitions.
    seed: int
        Seed used for dataset shuffling. It has no effect if `shuffle` is False.

    Examples
    --------
    In order to mimic the original behavior of the paper follow the setup below
    (the `class_assignment_mode="first-deterministic"`):

    >>> from flwr_datasets.partitioner import PathologicalPartitioner
    >>> from flwr_datasets import FederatedDataset
    >>>
    >>> partitioner = PathologicalPartitioner(
    >>>     num_partitions=10,
    >>>     partition_by="label",
    >>>     num_classes_per_partition=2,
    >>>     class_assignment_mode="first-deterministic"
    >>> )
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
    >>> partition = fds.load_partition(0)
    """

    def __init__(
        self,
        num_partitions: int,
        partition_by: str,
        num_classes_per_partition: int,
        class_assignment_mode: Literal[
            "random", "deterministic", "first-deterministic"
        ] = "random",
        shuffle: bool = True,
        seed: Optional[int] = 42,
    ) -> None:
        super().__init__()
        self._num_partitions = num_partitions
        self._partition_by = partition_by
        self._num_classes_per_partition = num_classes_per_partition
        self._class_assignment_mode = class_assignment_mode
        self._shuffle = shuffle
        self._seed = seed
        self._rng = np.random.default_rng(seed=self._seed)

        # Utility attributes
        # Mapping of partition_id to the indices of `self.dataset` it owns
        self._partition_id_to_indices: Dict[int, List[int]] = {}
        # Mapping of partition_id to the unique labels assigned to it
        self._partition_id_to_unique_labels: Dict[int, List[int]] = {
            pid: [] for pid in range(self._num_partitions)
        }
        self._unique_labels: List[int] = []
        # Count in how many partitions the label is used
        self._unique_label_to_times_used_counter: Dict[int, int] = {}
        # Guard so the (expensive) index assignment is computed only once
        self._partition_id_to_indices_determined = False

    def load_partition(self, partition_id: int) -> datasets.Dataset:
        """Load a partition based on the partition index.

        Parameters
        ----------
        partition_id : int
            the index that corresponds to the requested partition

        Returns
        -------
        dataset_partition : Dataset
            single partition of a dataset
        """
        # The partitioning is done lazily - only when the first partition is
        # requested. Only the first call creates the indices assignments for all the
        # partition indices.
        self._check_num_partitions_correctness_if_needed()
        self._determine_partition_id_to_indices_if_needed()
        return self.dataset.select(self._partition_id_to_indices[partition_id])

    @property
    def num_partitions(self) -> int:
        """Total number of partitions."""
        self._check_num_partitions_correctness_if_needed()
        self._determine_partition_id_to_indices_if_needed()
        return self._num_partitions

    def _determine_partition_id_to_indices_if_needed(self) -> None:
        """Create an assignment of indices to the partition indices."""
        if self._partition_id_to_indices_determined:
            return
        self._determine_partition_id_to_unique_labels()
        assert self._unique_labels is not None
        self._determine_unique_label_to_times_used()

        labels = np.asarray(self.dataset[self._partition_by])
        self._check_correctness_of_unique_label_to_times_used_counter(labels)
        for partition_id in range(self._num_partitions):
            self._partition_id_to_indices[partition_id] = []

        unused_labels = []
        for unique_label in self._unique_labels:
            if self._unique_label_to_times_used_counter[unique_label] == 0:
                unused_labels.append(unique_label)
                continue
            # Get the indices in the original dataset where the y == unique_label
            unique_label_to_indices = np.where(labels == unique_label)[0]

            # Split the samples of this label into as many (near-equal) parts as
            # there are partitions that were assigned this label
            split_unique_labels_to_indices = np.array_split(
                unique_label_to_indices,
                self._unique_label_to_times_used_counter[unique_label],
            )

            split_index = 0
            for partition_id in range(self._num_partitions):
                if unique_label in self._partition_id_to_unique_labels[partition_id]:
                    self._partition_id_to_indices[partition_id].extend(
                        split_unique_labels_to_indices[split_index]
                    )
                    split_index += 1

        if len(unused_labels) >= 1:
            # NOTE: the previous message referenced the removed parameter
            # `first_class_deterministic_assignment`; it now points users to the
            # actual `class_assignment_mode` parameter.
            warnings.warn(
                f"Classes: {unused_labels} will NOT be used due to the chosen "
                f"configuration. If it is undesired behavior consider setting "
                f"class_assignment_mode='first-deterministic' which, in case the "
                f"number of classes is smaller than the number of partitions, will "
                f"utilize all the classes for the created partitions.",
                stacklevel=1,
            )
        if self._shuffle:
            for indices in self._partition_id_to_indices.values():
                # In place shuffling
                self._rng.shuffle(indices)

        self._partition_id_to_indices_determined = True

    def _check_num_partitions_correctness_if_needed(self) -> None:
        """Test num_partitions when the dataset is given (in load_partition)."""
        if not self._partition_id_to_indices_determined:
            if self._num_partitions > self.dataset.num_rows:
                raise ValueError(
                    "The number of partitions needs to be smaller than the number of "
                    "samples in the dataset."
                )

    def _determine_partition_id_to_unique_labels(self) -> None:
        """Determine the assignment of unique labels to the partitions."""
        # NOTE(review): the order of `self._unique_labels` follows
        # `Dataset.unique`, which is not guaranteed to be sorted — the randomized
        # modes therefore depend on that order as well as on the seed.
        self._unique_labels = self.dataset.unique(self._partition_by)
        num_unique_classes = len(self._unique_labels)

        if self._num_classes_per_partition > num_unique_classes:
            raise ValueError(
                f"The specified `num_classes_per_partition`"
                f"={self._num_classes_per_partition} which is greater than the number "
                f"of unique classes in the given dataset={num_unique_classes}. "
                f"Reduce the `num_classes_per_partition` or use a different dataset "
                f"to apply this partitioning."
            )
        if self._class_assignment_mode == "first-deterministic":
            for partition_id in range(self._num_partitions):
                # Index into `self._unique_labels` instead of treating the raw
                # index as a label; the raw index is wrong for datasets whose
                # labels are not the consecutive integers 0..n-1 (e.g. strings).
                label = self._unique_labels[partition_id % num_unique_classes]
                self._partition_id_to_unique_labels[partition_id].append(label)

                # Fill up the remaining classes randomly (without repetition)
                while (
                    len(self._partition_id_to_unique_labels[partition_id])
                    < self._num_classes_per_partition
                ):
                    label = self._rng.choice(self._unique_labels, size=1)[0]
                    if label not in self._partition_id_to_unique_labels[partition_id]:
                        self._partition_id_to_unique_labels[partition_id].append(label)
        elif self._class_assignment_mode == "deterministic":
            for partition_id in range(self._num_partitions):
                labels = []
                for i in range(self._num_classes_per_partition):
                    label = self._unique_labels[
                        (partition_id + i) % len(self._unique_labels)
                    ]
                    labels.append(label)
                self._partition_id_to_unique_labels[partition_id] = labels
        elif self._class_assignment_mode == "random":
            for partition_id in range(self._num_partitions):
                labels = self._rng.choice(
                    self._unique_labels,
                    size=self._num_classes_per_partition,
                    replace=False,
                ).tolist()
                self._partition_id_to_unique_labels[partition_id] = labels
        else:
            raise ValueError(
                f"The supported class_assignment_mode are: 'random', 'deterministic', "
                f"'first-deterministic'. You provided: {self._class_assignment_mode}."
            )

    def _determine_unique_label_to_times_used(self) -> None:
        """Determine how many times the label is used in the partitions.

        This computation is based on the assignment of the label to the partition_id
        in the `_determine_partition_id_to_unique_labels` method.
        """
        for unique_label in self._unique_labels:
            self._unique_label_to_times_used_counter[unique_label] = 0
        for unique_labels in self._partition_id_to_unique_labels.values():
            for unique_label in unique_labels:
                self._unique_label_to_times_used_counter[unique_label] += 1

    def _check_correctness_of_unique_label_to_times_used_counter(
        self, labels: NDArray
    ) -> None:
        """Check if the number of times the label is possible to execute.

        The number of times the label can be used must be smaller or equal to the
        number of times that the label is present in the dataset.
        """
        for unique_label in self._unique_labels:
            num_unique = np.sum(labels == unique_label)
            if self._unique_label_to_times_used_counter[unique_label] > num_unique:
                raise ValueError(
                    f"Label: {unique_label} is needed to be assigned to more "
                    f"partitions "
                    f"({self._unique_label_to_times_used_counter[unique_label]})"
                    f" than there are samples (corresponding to this label) in the "
                    f"dataset ({num_unique}). Please decrease the `num_partitions`, "
                    f"`num_classes_per_partition` to avoid this situation, "
                    f"or try `class_assignment_mode='deterministic'` to create a more "
                    f"even distribution of classes along the partitions. "
                    f"Alternatively use a different dataset if you cannot adjust"
                    f" any of these parameters."
                )
Loading