Merge remote-tracking branch 'origin/main' into fds-rename-load-full-to-load-split
adam-narozniak committed Mar 13, 2024
2 parents 3a3587c + 930cdaf commit ff9652f
Showing 50 changed files with 1,356 additions and 890 deletions.
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/__init__.py
@@ -16,13 +16,15 @@


from flwr_datasets import partitioner, resplitter
from flwr_datasets import utils as utils
from flwr_datasets.common.version import package_version as _package_version
from flwr_datasets.federated_dataset import FederatedDataset

__all__ = [
"FederatedDataset",
"partitioner",
"resplitter",
"utils",
]

__version__ = _package_version
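
With this change the `utils` module is re-exported at the package top level, which makes helpers such as `divide_dataset` (imported by `federated_dataset.py` below) directly reachable. A minimal sketch of what the new export enables; the toy dataset is illustrative, and the 80/20 split mirrors the `partition_division` examples later in this diff:

    from datasets import Dataset
    from flwr_datasets import utils

    # A small in-memory dataset to divide on-edge.
    dataset = Dataset.from_dict({"x": list(range(10)), "label": [0, 1] * 5})

    # divide_dataset returns one Dataset per fraction in the list.
    train, test = utils.divide_dataset(dataset, [0.8, 0.2])
    print(len(train), len(test))  # 8 2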
133 changes: 122 additions & 11 deletions datasets/flwr_datasets/federated_dataset.py
@@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Dict, List, Optional, Tuple, Union, cast

import datasets
from datasets import Dataset, DatasetDict
@@ -25,9 +25,12 @@
_check_if_dataset_tested,
_instantiate_partitioners,
_instantiate_resplitter_if_needed,
divide_dataset,
)


# flake8: noqa: E501
# pylint: disable=line-too-long
class FederatedDataset:
"""Representation of a dataset for federated learning/evaluation/analytics.
@@ -51,6 +54,19 @@ class FederatedDataset:
(representing the number of IID partitions that this split should be partitioned
into). One or multiple `Partitioner` objects can be specified in that manner,
but at most one per split.
partition_division : Optional[Union[List[float], Tuple[float, ...],
Dict[str, float], Dict[str, Optional[Union[List[float], Tuple[float, ...],
Dict[str, float]]]]]]
Fractions specifying the division of the partition associated with a certain
split (and partitioner) that enable returning an already divided partition from
the `load_partition` method. You can think of this as an on-edge division of
the data into multiple divisions (e.g. into train and validation). You can name
the divisions by using a Dict, or specify them as a List/Tuple. If you
specified a single partitioner, you can provide the simplified form, e.g.
[0.8, 0.2] or {"partition_train": 0.8, "partition_test": 0.2}, but when
multiple partitioners are specified, you need to indicate the results of which
partitioner are further divided, e.g. {"train": [0.8, 0.2]} would result in
dividing only the partitions that are created from the "train" split.
shuffle : bool
Whether to randomize the order of samples. Applied prior to resplitting,
separately to each of the splits present in the dataset. It uses the `seed`
@@ -64,14 +80,18 @@ class FederatedDataset:
Use MNIST dataset for Federated Learning with 100 clients (edge devices):
>>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
>>> # Load partition for client with ID 10.
>>> partition = mnist_fds.load_partition(10, "train")
>>> # Use test split for centralized evaluation.
>>> centralized = mnist_fds.load_split("test")
Automatically divide the data returned from `load_partition`:
>>> mnist_fds = FederatedDataset(
>>> dataset="mnist",
>>> partitioners={"train": 100},
>>> partition_division=[0.8, 0.2],
>>> )
>>> partition_train, partition_test = mnist_fds.load_partition(10, "train")
"""

# pylint: disable=too-many-instance-attributes
@@ -82,6 +102,17 @@ def __init__(
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
] = None,
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
@@ -94,6 +125,9 @@ def __init__(
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
self._partition_division = self._initialize_partition_division(
partition_division
)
self._shuffle = shuffle
self._seed = seed
# _dataset is prepared lazily on the first call to `load_partition`
@@ -102,7 +136,11 @@ def __init__(
# Indicate if the dataset is prepared for `load_partition` or `load_split`
self._dataset_prepared: bool = False

def load_partition(
self,
node_id: int,
split: Optional[str] = None,
) -> Union[Dataset, List[Dataset], DatasetDict]:
"""Load the partition specified by the idx in the selected split.
The dataset is downloaded only when the first call to `load_partition` or
@@ -122,8 +160,13 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:
Returns
-------
partition : Union[Dataset, List[Dataset], DatasetDict]
Undivided or divided partition from the dataset split.
If `partition_division` is not specified then `Dataset` is returned.
If `partition_division` is specified as `List` or `Tuple` then
`List[Dataset]` is returned.
If `partition_division` is specified as `Dict` then `DatasetDict` is
returned.
"""
if not self._dataset_prepared:
self._prepare_dataset()
@@ -136,7 +179,16 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:
self._check_if_split_possible_to_federate(split)
partitioner: Partitioner = self._partitioners[split]
self._assign_dataset_to_partitioner(split)
partition = partitioner.load_partition(node_id)
if self._partition_division is None:
return partition
partition_division = self._partition_division.get(split)
if partition_division is None:
return partition
divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset(
partition, partition_division
)
return divided_partition

def load_split(self, split: str) -> Dataset:
"""Load the full split of the dataset.
@@ -230,3 +282,62 @@ def _check_if_no_split_keyword_possible(self) -> None:
"Please set the `split` argument. You can only omit the split keyword "
"if there is exactly one partitioner specified."
)

def _initialize_partition_division(
self,
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
],
) -> Optional[
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
]
]:
"""Create the partition division in the full format.
Reduced format (possible if only one partitioner exists):
Union[List[float], Tuple[float, ...], Dict[str, float]]
Full format: Dict[str, Reduced format]
The full format represents the split-to-division mapping.
"""
# Check for simple dict, list, or tuple types directly
if isinstance(partition_division, (list, tuple)) or (
isinstance(partition_division, dict)
and all(isinstance(value, float) for value in partition_division.values())
):
if len(self._partitioners) > 1:
raise ValueError(
f"The specified partition_division {partition_division} does not "
f"provide mapping to split but more than one partitioners is "
f"specified. Please adjust the partition_division specification to "
f"have the split names as the keys."
)
return cast(
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
{list(self._partitioners.keys())[0]: partition_division},
)
if isinstance(partition_division, dict):
return cast(
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
partition_division,
)
if partition_division is None:
return None
raise TypeError("Unsupported type for partition_division")
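
Putting the new `partition_division` plumbing together, `load_partition` now has the three return shapes documented above. A usage sketch following the docstring's MNIST example (the division names are arbitrary, and exact row counts depend on the partitioner):

    from flwr_datasets import FederatedDataset

    # List form: load_partition returns List[Dataset].
    fds = FederatedDataset(
        dataset="mnist",
        partitioners={"train": 100},
        partition_division=[0.8, 0.2],
    )
    partition_train, partition_valid = fds.load_partition(10, "train")

    # Dict form: load_partition returns a DatasetDict keyed by division name.
    fds = FederatedDataset(
        dataset="mnist",
        partitioners={"train": 100},
        partition_division={"partition_train": 0.8, "partition_test": 0.2},
    )
    divided = fds.load_partition(10, "train")
    print(divided["partition_train"].num_rows, divided["partition_test"].num_rows)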
44 changes: 43 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
@@ -17,7 +17,7 @@


import unittest
from typing import Dict, List, Optional, Tuple, Union
from unittest.mock import Mock, patch

import pytest
@@ -67,6 +67,48 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
len(dataset_partition0), len(dataset["train"]) // train_num_partitions
)

@parameterized.expand( # type: ignore
[
((0.2, 0.8), 2, False),
({"train": 0.2, "test": 0.8}, 2, False),
({"train": {"train": 0.2, "test": 0.8}}, 2, True),
# Not full dataset
([0.2, 0.1], 2, False),
({"train": 0.2, "test": 0.1}, 2, False),
(None, None, False),
],
)
def test_divide_partition_integration_size(
self,
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
],
expected_length: Optional[int],
add_test_partitioner: bool,
) -> None:
"""Test is the `partition_division` create correct data."""
partitioners: Dict[str, Union[Partitioner, int]] = {"train": 10}
if add_test_partitioner:
partitioners[self.test_split] = 10
dataset_fds = FederatedDataset(
dataset=self.dataset_name,
partitioners=partitioners,
partition_division=partition_division,
)
partition = dataset_fds.load_partition(0, "train")
if partition_division is None:
self.assertEqual(expected_length, None)
else:
self.assertEqual(len(partition), expected_length)

def test_load_split(self) -> None:
"""Test if the load_split works with the correct split name."""
dataset_fds = FederatedDataset(
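
The "Not full dataset" cases above are worth a note: the fractions do not have to sum to 1, and the remainder is simply not returned, so the expected length still counts one division per fraction. A standalone sketch of the same behavior, assuming `divide_dataset` as imported in `federated_dataset.py`:

    from datasets import Dataset
    from flwr_datasets.utils import divide_dataset

    dataset = Dataset.from_dict({"x": list(range(100))})

    # Two fractions -> two divisions; the remaining 70% of rows is not returned.
    head, tail = divide_dataset(dataset, [0.2, 0.1])
    print(len(head), len(tail))  # 20 10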
9 changes: 8 additions & 1 deletion datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
@@ -89,7 +89,7 @@ def __init__( # pylint: disable=R0913
partition_by: str,
alpha: Union[int, float, List[float], NDArrayFloat],
min_partition_size: int = 10,
self_balancing: bool = False,
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
@@ -132,6 +132,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._check_num_partitions_correctness_if_needed()
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

def _initialize_alpha(
self, alpha: Union[int, float, List[float], NDArrayFloat]
) -> NDArrayFloat:
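
Note the behavioral change in this file: the `self_balancing` default flips from True to False, so callers who relied on balanced Dirichlet partitions now have to opt in explicitly. A sketch; the `num_partitions` keyword is an assumption based on how the other partitioners in this diff are constructed:

    from flwr_datasets.partitioner import DirichletPartitioner

    # Opt back in to balancing now that the default is False.
    partitioner = DirichletPartitioner(
        num_partitions=10,
        partition_by="label",
        alpha=0.5,
        self_balancing=True,
    )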
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/partitioner/iid_partitioner.py
@@ -50,3 +50,8 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
return self.dataset.shard(
num_shards=self._num_partitions, index=node_id, contiguous=True
)

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
return self._num_partitions
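
For the IID partitioner the new `num_partitions` property is a plain attribute read, whereas the partitioners below first materialize their lazy node-to-indices mapping. A short usage sketch relying on the base class's `dataset` setter:

    from datasets import Dataset
    from flwr_datasets.partitioner import IidPartitioner

    partitioner = IidPartitioner(num_partitions=10)
    partitioner.dataset = Dataset.from_dict({"x": list(range(100))})

    print(partitioner.num_partitions)          # 10
    print(len(partitioner.load_partition(0)))  # 10 rows per contiguous shard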
11 changes: 11 additions & 0 deletions datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
@@ -119,6 +119,17 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._check_num_partitions_correctness_if_needed()
self._check_partition_sizes_correctness_if_needed()
self._check_the_sum_of_partition_sizes()
self._determine_num_unique_classes_if_needed()
self._alpha = self._initialize_alpha_if_needed(self._initial_alpha)
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

def _initialize_alpha_if_needed(
self, alpha: Union[int, float, List[float], NDArrayFloat]
) -> NDArrayFloat:
7 changes: 7 additions & 0 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner.py
@@ -65,6 +65,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
lambda row: row[self._partition_by] == self._node_id_to_natural_id[node_id]
)

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
if len(self._node_id_to_natural_id) == 0:
self._create_int_node_id_to_natural_id()
return len(self._node_id_to_natural_id)

@property
def node_id_to_natural_id(self) -> Dict[int, str]:
"""Node id to corresponding natural id present.
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/partitioner/partitioner.py
@@ -79,3 +79,8 @@ def is_dataset_assigned(self) -> bool:
True if a dataset is assigned, otherwise False.
"""
return self._dataset is not None

@property
@abstractmethod
def num_partitions(self) -> int:
"""Total number of partitions."""
9 changes: 9 additions & 0 deletions datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -179,6 +179,15 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._check_num_partitions_correctness_if_needed()
self._check_possibility_of_partitions_creation()
self._sort_dataset_if_needed()
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914
"""Assign sample indices to each node id.
6 changes: 6 additions & 0 deletions datasets/flwr_datasets/partitioner/size_partitioner.py
@@ -84,6 +84,12 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
self._determine_node_id_to_indices_if_needed()
return self.dataset.select(self._node_id_to_indices[node_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._determine_node_id_to_indices_if_needed()
return self._num_partitions

@property
def node_id_to_size(self) -> Dict[int, int]:
"""Node id to the number of samples."""