Skip to content

Commit

Permalink
break(datasets) Rename resplitter parameter and type to preprocessor (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak authored May 29, 2024
1 parent 03c9f79 commit a443f86
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 116 deletions.
4 changes: 2 additions & 2 deletions datasets/flwr_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Flower Datasets main package."""


from flwr_datasets import partitioner, resplitter
from flwr_datasets import partitioner, preprocessor
from flwr_datasets import utils as utils
from flwr_datasets.common.version import package_version as _package_version
from flwr_datasets.federated_dataset import FederatedDataset

__all__ = [
"FederatedDataset",
"partitioner",
"resplitter",
"preprocessor",
"utils",
]

Expand Down
22 changes: 12 additions & 10 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.resplitter import Resplitter
from flwr_datasets.preprocessor import Preprocessor
from flwr_datasets.utils import (
_check_if_dataset_tested,
_instantiate_merger_if_needed,
_instantiate_partitioners,
_instantiate_resplitter_if_needed,
)


Expand All @@ -45,9 +45,11 @@ class FederatedDataset:
subset : str
Secondary information regarding the dataset, most often subset or version
(that is passed to the name in datasets.load_dataset).
resplitter : Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
`Callable` that transforms `DatasetDict` splits, or configuration dict for
`MergeResplitter`.
preprocessor : Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]]
`Callable` that transforms `DatasetDict` by resplitting, removing
features, creating new features, performing any other preprocessing operation,
or configuration dict for `Merger`. Applied after shuffling. If None,
no operation is applied.
partitioners : Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
Expand Down Expand Up @@ -79,16 +81,16 @@ def __init__(
*,
dataset: str,
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
preprocessor: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._subset: Optional[str] = subset
self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed(
resplitter
self._preprocessor: Optional[Preprocessor] = _instantiate_merger_if_needed(
preprocessor
)
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
Expand Down Expand Up @@ -242,8 +244,8 @@ def _prepare_dataset(self) -> None:
# Note it shuffles all the splits. The self._dataset is DatasetDict
# so e.g. {"train": train_data, "test": test_data}. All splits get shuffled.
self._dataset = self._dataset.shuffle(seed=self._seed)
if self._resplitter:
self._dataset = self._resplitter(self._dataset)
if self._preprocessor:
self._dataset = self._preprocessor(self._dataset)
self._dataset_prepared = True

def _check_if_no_split_keyword_possible(self) -> None:
Expand Down
14 changes: 7 additions & 7 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,20 @@ def test_resplit_dataset_into_one(self) -> None:
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"train": 100},
resplitter={"full": ("train", self.test_split)},
preprocessor={"full": ("train", self.test_split)},
)
full = fds.load_split("full")
self.assertEqual(dataset_length, len(full))

# pylint: disable=protected-access
def test_resplit_dataset_to_change_names(self) -> None:
"""Test resplitter to change the names of the partitions."""
"""Test preprocessor to change the names of the partitions."""
if self.test_split is None:
return
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"new_train": 100},
resplitter={
preprocessor={
"new_train": ("train",),
"new_" + self.test_split: (self.test_split,),
},
Expand All @@ -195,7 +195,7 @@ def test_resplit_dataset_to_change_names(self) -> None:
)

def test_resplit_dataset_by_callable(self) -> None:
"""Test resplitter to change the names of the partitions."""
"""Test preprocessor to change the names of the partitions."""
if self.test_split is None:
return

Expand All @@ -209,7 +209,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
)

fds = FederatedDataset(
dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit
dataset=self.dataset_name, partitioners={"train": 100}, preprocessor=resplit
)
full = fds.load_split("full")
dataset = datasets.load_dataset(self.dataset_name)
Expand Down Expand Up @@ -298,7 +298,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
fds = FederatedDataset(
dataset="does-not-matter",
partitioners={"train": 10},
resplitter=resplit,
preprocessor=resplit,
shuffle=True,
)
train = fds.load_split("train")
Expand Down Expand Up @@ -411,7 +411,7 @@ def test_cannot_use_the_old_split_names(self) -> None:
fds = FederatedDataset(
dataset="mnist",
partitioners={"train": 100},
resplitter={"full": ("train", "test")},
preprocessor={"full": ("train", "test")},
)
with self.assertRaises(ValueError):
fds.load_partition(0, "train")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Resplitter package."""
"""Preprocessor package."""


from .divide_resplitter import DivideResplitter
from .merge_resplitter import MergeResplitter
from .resplitter import Resplitter
from .divider import Divider
from .merger import Merger
from .preprocessor import Preprocessor

__all__ = [
"DivideResplitter",
"MergeResplitter",
"Resplitter",
"Merger",
"Preprocessor",
"Divider",
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DivideResplitter class for Flower Datasets."""
"""Divider class for Flower Datasets."""


import collections
Expand All @@ -25,7 +25,7 @@

# flake8: noqa: E501
# pylint: disable=line-too-long
class DivideResplitter:
class Divider:
"""Divide existing split(s) of the dataset and assign them custom names.
Create new `DatasetDict` with new split names with corresponding percentages of data
Expand Down Expand Up @@ -66,14 +66,14 @@ class DivideResplitter:
>>> # Assuming there is a dataset_dict of type `DatasetDict`
>>> # dataset_dict is {"train": train-data, "test": test-data}
>>> resplitter = DivideResplitter(
>>> divider = Divider(
>>> divide_config={
>>> "train": 0.8,
>>> "valid": 0.2,
>>> }
>>> divide_split="train",
>>> )
>>> new_dataset_dict = resplitter(dataset_dict)
>>> new_dataset_dict = divider(dataset_dict)
>>> # new_dataset_dict is
>>> # {"train": 80% of train, "valid": 20% of train, "test": test-data}
Expand All @@ -83,7 +83,7 @@ class DivideResplitter:
>>> # Assuming there is a dataset_dict of type `DatasetDict`
>>> # dataset_dict is {"train": train-data, "test": test-data}
>>> resplitter = DivideResplitter(
>>> divider = Divider(
>>> divide_config={
>>> "train": {
>>> "train": 0.8,
Expand All @@ -92,7 +92,7 @@ class DivideResplitter:
>>> "test": {"test-a": 0.4, "test-b": 0.6 }
>>> }
>>> )
>>> new_dataset_dict = resplitter(dataset_dict)
>>> new_dataset_dict = divider(dataset_dict)
>>> # new_dataset_dict is
>>> # {"train": 80% of train, "valid": 20% of train,
>>> # "test-a": 40% of test, "test-b": 60% of test}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DivideResplitter tests."""
"""Divider tests."""

import unittest
from typing import Dict, Union

from parameterized import parameterized_class

from datasets import Dataset, DatasetDict
from flwr_datasets.resplitter import DivideResplitter
from flwr_datasets.preprocessor import Divider


@parameterized_class(
Expand Down Expand Up @@ -80,8 +80,8 @@
),
],
)
class TestDivideResplitter(unittest.TestCase):
"""DivideResplitter tests."""
class TestDivider(unittest.TestCase):
"""Divider tests."""

divide_config: Union[
Dict[str, float],
Expand All @@ -105,27 +105,27 @@ def setUp(self) -> None:

def test_resplitting_correct_new_split_names(self) -> None:
"""Test if resplitting produces requested new splits."""
resplitter = DivideResplitter(
divider = Divider(
self.divide_config, self.divide_split, self.drop_remaining_splits
)
resplit_dataset = resplitter(self.dataset_dict)
resplit_dataset = divider(self.dataset_dict)
new_keys = set(resplit_dataset.keys())
self.assertEqual(set(self.split_name_to_size.keys()), new_keys)

def test_resplitting_correct_new_split_sizes(self) -> None:
"""Test if resplitting produces correct sizes of splits."""
resplitter = DivideResplitter(
divider = Divider(
self.divide_config, self.divide_split, self.drop_remaining_splits
)
resplit_dataset = resplitter(self.dataset_dict)
resplit_dataset = divider(self.dataset_dict)
split_to_size = {
split_name: len(split) for split_name, split in resplit_dataset.items()
}
self.assertEqual(self.split_name_to_size, split_to_size)


class TestDivideResplitterIncorrectUseCases(unittest.TestCase):
"""Resplitter tests."""
class TestDividerIncorrectUseCases(unittest.TestCase):
"""Divider tests."""

def setUp(self) -> None:
"""Set up the dataset with 3 splits for tests."""
Expand All @@ -144,21 +144,17 @@ def test_doubling_names_in_config(self) -> None:
drop_remaining_splits = False

with self.assertRaises(ValueError):
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
_ = resplitter(self.dataset_dict)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
_ = divider(self.dataset_dict)

def test_duplicate_names_in_config_and_dataset_split_names_multisplit(self) -> None:
"""Test if resplitting raises when the name collides with the old name."""
divide_config = {"train": {"valid": 0.5}}
divide_split = None
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)

def test_duplicate_names_in_config_and_dataset_split_names_single_split(
self,
Expand All @@ -167,77 +163,63 @@ def test_duplicate_names_in_config_and_dataset_split_names_single_split(
divide_config = {"valid": 0.5}
divide_split = "train"
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)

def test_fraction_sum_up_to_more_than_one_multisplit(self) -> None:
"""Test if resplitting raises when fractions sum up to > 1.0."""
divide_config = {"train": {"train_1": 0.5, "train_2": 0.7}}
divide_split = None
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)

def test_fraction_sum_up_to_more_than_one_single_split(self) -> None:
"""Test if resplitting raises when fractions sum up to > 1.0."""
divide_config = {"train_1": 0.5, "train_2": 0.7}
divide_split = "train"
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)

def test_sample_sizes_sum_up_to_more_than_dataset_size_single_split(self) -> None:
"""Test if resplitting raises when sample sizes sum up to > len(dataset)."""
divide_config = {"train": {"train_1": 20, "train_2": 25}}
divide_split = None
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)

def test_sample_sizes_sum_up_to_more_than_dataset_size_multisplit(self) -> None:
"""Test if resplitting raises when sample sizes sum up to > len(dataset)."""
divide_config = {"train_1": 20, "train_2": 25}
divide_split = "train"
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)

def test_too_small_size_values_create_empty_dataset_single_split(self) -> None:
"""Test if resplitting raises when fraction creates empty dataset."""
divide_config = {"train": {"train_1": 0.2, "train_2": 0.0001}}
divide_split = None
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)

def test_too_small_size_values_create_empty_dataset_multisplit(self) -> None:
"""Test if resplitting raises when fraction creates empty dataset."""
divide_config = {"train_1": 0.2, "train_2": 0.0001}
divide_split = "train"
drop_remaining_splits = False
resplitter = DivideResplitter(
divide_config, divide_split, drop_remaining_splits
)
divider = Divider(divide_config, divide_split, drop_remaining_splits)
with self.assertRaises(ValueError):
_ = resplitter(self.dataset_dict)
_ = divider(self.dataset_dict)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit a443f86

Please sign in to comment.