From a443f86ee7a59e4dbdb9c76200cc924d19edbc29 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Wed, 29 May 2024 12:47:07 +0200 Subject: [PATCH] break(datasets) Rename resplitter parameter and type to preprocessor (#3476) --- datasets/flwr_datasets/__init__.py | 4 +- datasets/flwr_datasets/federated_dataset.py | 22 +++--- .../flwr_datasets/federated_dataset_test.py | 14 ++-- .../{resplitter => preprocessor}/__init__.py | 14 ++-- .../divider.py} | 12 +-- .../divider_test.py} | 74 +++++++------------ .../merger.py} | 8 +- .../merger_test.py} | 46 ++++++------ .../preprocessor.py} | 4 +- datasets/flwr_datasets/utils.py | 18 ++--- 10 files changed, 100 insertions(+), 116 deletions(-) rename datasets/flwr_datasets/{resplitter => preprocessor}/__init__.py (76%) rename datasets/flwr_datasets/{resplitter/divide_resplitter.py => preprocessor/divider.py} (98%) rename datasets/flwr_datasets/{resplitter/divide_resplitter_test.py => preprocessor/divider_test.py} (79%) rename datasets/flwr_datasets/{resplitter/merge_resplitter.py => preprocessor/merger.py} (96%) rename datasets/flwr_datasets/{resplitter/merge_resplitter_test.py => preprocessor/merger_test.py} (81%) rename datasets/flwr_datasets/{resplitter/resplitter.py => preprocessor/preprocessor.py} (91%) diff --git a/datasets/flwr_datasets/__init__.py b/datasets/flwr_datasets/__init__.py index 0b9a6685427b..2d6ecb414498 100644 --- a/datasets/flwr_datasets/__init__.py +++ b/datasets/flwr_datasets/__init__.py @@ -15,7 +15,7 @@ """Flower Datasets main package.""" -from flwr_datasets import partitioner, resplitter +from flwr_datasets import partitioner, preprocessor from flwr_datasets import utils as utils from flwr_datasets.common.version import package_version as _package_version from flwr_datasets.federated_dataset import FederatedDataset @@ -23,7 +23,7 @@ __all__ = [ "FederatedDataset", "partitioner", - "resplitter", + "preprocessor", "utils", ] diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 6c41eaa3562f..5d98d01d4941 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -20,11 +20,11 @@ import datasets from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import Partitioner -from flwr_datasets.resplitter import Resplitter +from flwr_datasets.preprocessor import Preprocessor from flwr_datasets.utils import ( _check_if_dataset_tested, + _instantiate_merger_if_needed, _instantiate_partitioners, - _instantiate_resplitter_if_needed, ) @@ -45,9 +45,11 @@ class FederatedDataset: subset : str Secondary information regarding the dataset, most often subset or version (that is passed to the name in datasets.load_dataset). - resplitter : Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] - `Callable` that transforms `DatasetDict` splits, or configuration dict for - `MergeResplitter`. + preprocessor : Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] + `Callable` that transforms `DatasetDict` by resplitting, removing + features, creating new features, performing any other preprocessing operation, + or configuration dict for `Merger`. Applied after shuffling. If None, + no operation is applied. 
partitioners : Dict[str, Union[Partitioner, int]] A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int` (representing the number of IID partitions that this split should be partitioned @@ -79,7 +81,7 @@ def __init__( *, dataset: str, subset: Optional[str] = None, - resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, + preprocessor: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], shuffle: bool = True, seed: Optional[int] = 42, @@ -87,8 +89,8 @@ def __init__( _check_if_dataset_tested(dataset) self._dataset_name: str = dataset self._subset: Optional[str] = subset - self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed( - resplitter + self._preprocessor: Optional[Preprocessor] = _instantiate_merger_if_needed( + preprocessor ) self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners @@ -242,8 +244,8 @@ def _prepare_dataset(self) -> None: # Note it shuffles all the splits. The self._dataset is DatasetDict # so e.g. {"train": train_data, "test": test_data}. All splits get shuffled. self._dataset = self._dataset.shuffle(seed=self._seed) - if self._resplitter: - self._dataset = self._resplitter(self._dataset) + if self._preprocessor: + self._dataset = self._preprocessor(self._dataset) self._dataset_prepared = True def _check_if_no_split_keyword_possible(self) -> None: diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 5d5179122e3b..f65aa6346f3a 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -170,20 +170,20 @@ def test_resplit_dataset_into_one(self) -> None: fds = FederatedDataset( dataset=self.dataset_name, partitioners={"train": 100}, - resplitter={"full": ("train", self.test_split)}, + preprocessor={"full": ("train", self.test_split)}, ) full = fds.load_split("full") self.assertEqual(dataset_length, len(full)) # pylint: disable=protected-access def test_resplit_dataset_to_change_names(self) -> None: - """Test resplitter to change the names of the partitions.""" + """Test preprocessor to change the names of the partitions.""" if self.test_split is None: return fds = FederatedDataset( dataset=self.dataset_name, partitioners={"new_train": 100}, - resplitter={ + preprocessor={ "new_train": ("train",), "new_" + self.test_split: (self.test_split,), }, @@ -195,7 +195,7 @@ def test_resplit_dataset_to_change_names(self) -> None: ) def test_resplit_dataset_by_callable(self) -> None: - """Test resplitter to change the names of the partitions.""" + """Test preprocessor to change the names of the partitions.""" if self.test_split is None: return @@ -209,7 +209,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict: ) fds = FederatedDataset( - dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit + dataset=self.dataset_name, partitioners={"train": 100}, preprocessor=resplit ) full = fds.load_split("full") dataset = datasets.load_dataset(self.dataset_name) @@ -298,7 +298,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict: fds = FederatedDataset( dataset="does-not-matter", partitioners={"train": 10}, - resplitter=resplit, + preprocessor=resplit, shuffle=True, ) train = fds.load_split("train") @@ -411,7 +411,7 @@ def test_cannot_use_the_old_split_names(self) -> None: fds = FederatedDataset( dataset="mnist", partitioners={"train": 100}, - resplitter={"full": ("train", "test")}, + 
preprocessor={"full": ("train", "test")}, ) with self.assertRaises(ValueError): fds.load_partition(0, "train") diff --git a/datasets/flwr_datasets/resplitter/__init__.py b/datasets/flwr_datasets/preprocessor/__init__.py similarity index 76% rename from datasets/flwr_datasets/resplitter/__init__.py rename to datasets/flwr_datasets/preprocessor/__init__.py index bf39786e0593..bab5d82a2035 100644 --- a/datasets/flwr_datasets/resplitter/__init__.py +++ b/datasets/flwr_datasets/preprocessor/__init__.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Resplitter package.""" +"""Preprocessor package.""" -from .divide_resplitter import DivideResplitter -from .merge_resplitter import MergeResplitter -from .resplitter import Resplitter +from .divider import Divider +from .merger import Merger +from .preprocessor import Preprocessor __all__ = [ - "DivideResplitter", - "MergeResplitter", - "Resplitter", + "Merger", + "Preprocessor", + "Divider", ] diff --git a/datasets/flwr_datasets/resplitter/divide_resplitter.py b/datasets/flwr_datasets/preprocessor/divider.py similarity index 98% rename from datasets/flwr_datasets/resplitter/divide_resplitter.py rename to datasets/flwr_datasets/preprocessor/divider.py index 56150b51af85..9d7570de4cea 100644 --- a/datasets/flwr_datasets/resplitter/divide_resplitter.py +++ b/datasets/flwr_datasets/preprocessor/divider.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DivideResplitter class for Flower Datasets.""" +"""Divider class for Flower Datasets.""" import collections @@ -25,7 +25,7 @@ # flake8: noqa: E501 # pylint: disable=line-too-long -class DivideResplitter: +class Divider: """Dive existing split(s) of the dataset and assign them custom names. 
Create new `DatasetDict` with new split names with corresponding percentages of data @@ -66,14 +66,14 @@ class DivideResplitter: >>> # Assuming there is a dataset_dict of type `DatasetDict` >>> # dataset_dict is {"train": train-data, "test": test-data} - >>> resplitter = DivideResplitter( + >>> divider = Divider( >>> divide_config={ >>> "train": 0.8, >>> "valid": 0.2, >>> } >>> divide_split="train", >>> ) - >>> new_dataset_dict = resplitter(dataset_dict) + >>> new_dataset_dict = divider(dataset_dict) >>> # new_dataset_dict is >>> # {"train": 80% of train, "valid": 20% of train, "test": test-data} @@ -83,7 +83,7 @@ class DivideResplitter: >>> # Assuming there is a dataset_dict of type `DatasetDict` >>> # dataset_dict is {"train": train-data, "test": test-data} - >>> resplitter = DivideResplitter( + >>> divider = Divider( >>> divide_config={ >>> "train": { >>> "train": 0.8, @@ -92,7 +92,7 @@ class DivideResplitter: >>> "test": {"test-a": 0.4, "test-b": 0.6 } >>> } >>> ) - >>> new_dataset_dict = resplitter(dataset_dict) + >>> new_dataset_dict = divider(dataset_dict) >>> # new_dataset_dict is >>> # {"train": 80% of train, "valid": 20% of train, >>> # "test-a": 40% of test, "test-b": 60% of test} diff --git a/datasets/flwr_datasets/resplitter/divide_resplitter_test.py b/datasets/flwr_datasets/preprocessor/divider_test.py similarity index 79% rename from datasets/flwr_datasets/resplitter/divide_resplitter_test.py rename to datasets/flwr_datasets/preprocessor/divider_test.py index 143297fcc1a7..ed282fbc18be 100644 --- a/datasets/flwr_datasets/resplitter/divide_resplitter_test.py +++ b/datasets/flwr_datasets/preprocessor/divider_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DivideResplitter tests.""" +"""Divider tests.""" import unittest from typing import Dict, Union @@ -20,7 +20,7 @@ from parameterized import parameterized_class from datasets import Dataset, DatasetDict -from flwr_datasets.resplitter import DivideResplitter +from flwr_datasets.preprocessor import Divider @parameterized_class( @@ -80,8 +80,8 @@ ), ], ) -class TestDivideResplitter(unittest.TestCase): - """DivideResplitter tests.""" +class TestDivider(unittest.TestCase): + """Divider tests.""" divide_config: Union[ Dict[str, float], @@ -105,27 +105,27 @@ def setUp(self) -> None: def test_resplitting_correct_new_split_names(self) -> None: """Test if resplitting produces requested new splits.""" - resplitter = DivideResplitter( + divider = Divider( self.divide_config, self.divide_split, self.drop_remaining_splits ) - resplit_dataset = resplitter(self.dataset_dict) + resplit_dataset = divider(self.dataset_dict) new_keys = set(resplit_dataset.keys()) self.assertEqual(set(self.split_name_to_size.keys()), new_keys) def test_resplitting_correct_new_split_sizes(self) -> None: """Test if resplitting produces correct sizes of splits.""" - resplitter = DivideResplitter( + divider = Divider( self.divide_config, self.divide_split, self.drop_remaining_splits ) - resplit_dataset = resplitter(self.dataset_dict) + resplit_dataset = divider(self.dataset_dict) split_to_size = { split_name: len(split) for split_name, split in resplit_dataset.items() } self.assertEqual(self.split_name_to_size, split_to_size) -class TestDivideResplitterIncorrectUseCases(unittest.TestCase): - """Resplitter tests.""" +class TestDividerIncorrectUseCases(unittest.TestCase): + """Divider tests.""" def 
setUp(self) -> None: """Set up the dataset with 3 splits for tests.""" @@ -144,21 +144,17 @@ def test_doubling_names_in_config(self) -> None: drop_remaining_splits = False with self.assertRaises(ValueError): - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) - _ = resplitter(self.dataset_dict) + divider = Divider(divide_config, divide_split, drop_remaining_splits) + _ = divider(self.dataset_dict) def test_duplicate_names_in_config_and_dataset_split_names_multisplit(self) -> None: """Test if resplitting raises when the name collides with the old name.""" divide_config = {"train": {"valid": 0.5}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_duplicate_names_in_config_and_dataset_split_names_single_split( self, @@ -167,77 +163,63 @@ def test_duplicate_names_in_config_and_dataset_split_names_single_split( divide_config = {"valid": 0.5} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_fraction_sum_up_to_more_than_one_multisplit(self) -> None: """Test if resplitting raises when fractions sum up to > 1.0 .""" divide_config = {"train": {"train_1": 0.5, "train_2": 0.7}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_fraction_sum_up_to_more_than_one_single_split(self) -> None: """Test if resplitting raises when fractions sum up to > 1.0 .""" divide_config = {"train_1": 0.5, "train_2": 0.7} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_sample_sizes_sum_up_to_more_than_dataset_size_single_split(self) -> None: """Test if resplitting raises when samples size sum up to > len(datset) .""" divide_config = {"train": {"train_1": 20, "train_2": 25}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_sample_sizes_sum_up_to_more_than_dataset_size_multisplit(self) -> None: """Test if resplitting raises when samples size sum up to > len(datset) .""" divide_config = {"train_1": 20, "train_2": 25} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = 
divider(self.dataset_dict) def test_too_small_size_values_create_empty_dataset_single_split(self) -> None: """Test if resplitting raises when fraction creates empty dataset.""" divide_config = {"train": {"train_1": 0.2, "train_2": 0.0001}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_too_small_size_values_create_empty_dataset_multisplit(self) -> None: """Test if resplitting raises when fraction creates empty dataset.""" divide_config = {"train_1": 0.2, "train_2": 0.0001} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) if __name__ == "__main__": diff --git a/datasets/flwr_datasets/resplitter/merge_resplitter.py b/datasets/flwr_datasets/preprocessor/merger.py similarity index 96% rename from datasets/flwr_datasets/resplitter/merge_resplitter.py rename to datasets/flwr_datasets/preprocessor/merger.py index 6bb8f23e60dc..2b76dbbafe4b 100644 --- a/datasets/flwr_datasets/resplitter/merge_resplitter.py +++ b/datasets/flwr_datasets/preprocessor/merger.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""MergeResplitter class for Flower Datasets.""" +"""Merger class for Flower Datasets.""" import collections @@ -24,7 +24,7 @@ from datasets import Dataset, DatasetDict -class MergeResplitter: +class Merger: """Merge existing splits of the dataset and assign them custom names. Create new `DatasetDict` with new split names corresponding to the merged existing @@ -43,13 +43,13 @@ class MergeResplitter: >>> # Assuming there is a dataset_dict of type `DatasetDict` >>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data} - >>> merge_resplitter = MergeResplitter( + >>> merger = Merger( >>> merge_config={ >>> "new_train": ("train", "valid"), >>> "test": ("test", ) >>> } >>> ) - >>> new_dataset_dict = merge_resplitter(dataset_dict) + >>> new_dataset_dict = merger(dataset_dict) >>> # new_dataset_dict is >>> # {"new_train": concatenation of train-data and valid-data, "test": test-data} """ diff --git a/datasets/flwr_datasets/resplitter/merge_resplitter_test.py b/datasets/flwr_datasets/preprocessor/merger_test.py similarity index 81% rename from datasets/flwr_datasets/resplitter/merge_resplitter_test.py rename to datasets/flwr_datasets/preprocessor/merger_test.py index ebbdfb4022b0..d5c69387e53d 100644 --- a/datasets/flwr_datasets/resplitter/merge_resplitter_test.py +++ b/datasets/flwr_datasets/preprocessor/merger_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Resplitter tests.""" +"""Preprocessor tests.""" import unittest @@ -21,11 +21,11 @@ import pytest from datasets import Dataset, DatasetDict -from flwr_datasets.resplitter.merge_resplitter import MergeResplitter +from flwr_datasets.preprocessor.merger import Merger -class TestResplitter(unittest.TestCase): - """Resplitter tests.""" +class TestMerger(unittest.TestCase): + """Preprocessor tests.""" def setUp(self) -> None: """Set up the dataset with 3 splits for tests.""" @@ -40,29 +40,29 @@ def setUp(self) -> None: def test_resplitting_train_size(self) -> None: """Test if resplitting for just renaming keeps the lengths correct.""" strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)} - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["new_train"]), 3) def test_resplitting_valid_size(self) -> None: """Test if resplitting for just renaming keeps the lengths correct.""" strategy: Dict[str, Tuple[str, ...]] = {"new_valid": ("valid",)} - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["new_valid"]), 2) def test_resplitting_test_size(self) -> None: """Test if resplitting for just renaming keeps the lengths correct.""" strategy: Dict[str, Tuple[str, ...]] = {"new_test": ("test",)} - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["new_test"]), 1) def test_resplitting_train_the_same(self) -> None: """Test if resplitting for just renaming keeps the dataset the same.""" strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)} - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertTrue( datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"]) ) @@ -72,8 +72,8 @@ def test_combined_train_valid_size(self) -> None: strategy: Dict[str, Tuple[str, ...]] = { "train_valid_combined": ("train", "valid") } - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["train_valid_combined"]), 5) def test_resplitting_test_with_combined_strategy_size(self) -> None: @@ -82,8 +82,8 @@ def test_resplitting_test_with_combined_strategy_size(self) -> None: "train_valid_combined": ("train", "valid"), "test": ("test",), } - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["test"]), 1) def test_invalid_resplit_strategy_exception_message(self) -> None: @@ -92,20 +92,20 @@ def test_invalid_resplit_strategy_exception_message(self) -> None: "new_train": ("invalid_split",), "new_test": ("test",), } - resplitter = MergeResplitter(strategy) + merger = Merger(strategy) with self.assertRaisesRegex( ValueError, "The given dataset key 'invalid_split' is not present" ): - resplitter(self.dataset_dict) + merger(self.dataset_dict) def test_nonexistent_split_in_strategy(self) -> None: """Test if the exception is raised when the 
nonexistent split name is given.""" strategy: Dict[str, Tuple[str, ...]] = {"new_split": ("nonexistent_split",)} - resplitter = MergeResplitter(strategy) + merger = Merger(strategy) with self.assertRaisesRegex( ValueError, "The given dataset key 'nonexistent_split' is not present" ): - resplitter(self.dataset_dict) + merger(self.dataset_dict) def test_duplicate_merge_split_name(self) -> None: # pylint: disable=R0201 """Test that the new split names are not the same.""" @@ -114,17 +114,17 @@ def test_duplicate_merge_split_name(self) -> None: # pylint: disable=R0201 "test": ("train",), } with pytest.warns(UserWarning): - _ = MergeResplitter(strategy) + _ = Merger(strategy) def test_empty_dataset_dict(self) -> None: """Test that the error is raised when the empty DatasetDict is given.""" empty_dataset = DatasetDict({}) strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)} - resplitter = MergeResplitter(strategy) + merger = Merger(strategy) with self.assertRaisesRegex( ValueError, "The given dataset key 'train' is not present" ): - resplitter(empty_dataset) + merger(empty_dataset) def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: diff --git a/datasets/flwr_datasets/resplitter/resplitter.py b/datasets/flwr_datasets/preprocessor/preprocessor.py similarity index 91% rename from datasets/flwr_datasets/resplitter/resplitter.py rename to datasets/flwr_datasets/preprocessor/preprocessor.py index 206e2e85730c..c137b98eeeee 100644 --- a/datasets/flwr_datasets/resplitter/resplitter.py +++ b/datasets/flwr_datasets/preprocessor/preprocessor.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Resplitter.""" +"""Preprocessor.""" from typing import Callable from datasets import DatasetDict -Resplitter = Callable[[DatasetDict], DatasetDict] +Preprocessor = Callable[[DatasetDict], DatasetDict] diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index c6f6900a99cd..0ecb96ac9456 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -20,8 +20,8 @@ from datasets import Dataset, DatasetDict, concatenate_datasets from flwr_datasets.partitioner import IidPartitioner, Partitioner -from flwr_datasets.resplitter import Resplitter -from flwr_datasets.resplitter.merge_resplitter import MergeResplitter +from flwr_datasets.preprocessor import Preprocessor +from flwr_datasets.preprocessor.merger import Merger tested_datasets = [ "mnist", @@ -75,13 +75,13 @@ def _instantiate_partitioners( return instantiated_partitioners -def _instantiate_resplitter_if_needed( - resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] -) -> Optional[Resplitter]: - """Instantiate `MergeResplitter` if resplitter is merge_config.""" - if resplitter and isinstance(resplitter, Dict): - resplitter = MergeResplitter(merge_config=resplitter) - return cast(Optional[Resplitter], resplitter) +def _instantiate_merger_if_needed( + merger: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] +) -> Optional[Preprocessor]: + """Instantiate `Merger` if preprocessor is merge_config.""" + if merger and isinstance(merger, Dict): + merger = Merger(merge_config=merger) + return cast(Optional[Preprocessor], merger) def _check_if_dataset_tested(dataset: str) -> None:
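
For reference, a minimal usage sketch of the renamed API introduced by this patch: the former `resplitter` argument is now `preprocessor`, `Resplitter` is now `Preprocessor`, `MergeResplitter` is now `Merger`, and `DivideResplitter` is now `Divider`. This sketch is not part of the patch; the dataset name, split names, partition counts, and fractions below are illustrative assumptions based on the docstrings and tests above.

    from datasets import DatasetDict
    from flwr_datasets import FederatedDataset
    from flwr_datasets.preprocessor import Divider, Merger

    # Option 1: pass a merge-config dict; FederatedDataset wraps it in a `Merger`.
    # Equivalently: preprocessor=Merger(merge_config={"full": ("train", "test")})
    fds = FederatedDataset(
        dataset="mnist",  # illustrative; any dataset with "train" and "test" splits
        partitioners={"full": 100},
        preprocessor={"full": ("train", "test")},
    )
    partition = fds.load_partition(0, "full")

    # Option 2: pass a Preprocessor instance explicitly, e.g. a `Divider` that
    # splits "train" into 80/20 train/valid (fractions are illustrative).
    divider = Divider(divide_config={"train": 0.8, "valid": 0.2}, divide_split="train")
    fds = FederatedDataset(
        dataset="mnist",
        partitioners={"train": 100},
        preprocessor=divider,
    )
    valid_split = fds.load_split("valid")

    # Option 3: any `Callable[[DatasetDict], DatasetDict]` satisfies `Preprocessor`.
    def keep_train_only(dataset_dict: DatasetDict) -> DatasetDict:
        return DatasetDict({"train": dataset_dict["train"]})

    fds = FederatedDataset(
        dataset="mnist",
        partitioners={"train": 100},
        preprocessor=keep_train_only,
    )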