diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 6c41eaa3562f..1595abb9e750 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -20,7 +20,7 @@ import datasets from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import Partitioner -from flwr_datasets.resplitter import Resplitter +from flwr_datasets.preprocessor import Preprocessor from flwr_datasets.utils import ( _check_if_dataset_tested, _instantiate_partitioners, @@ -45,7 +45,7 @@ class FederatedDataset: subset : str Secondary information regarding the dataset, most often subset or version (that is passed to the name in datasets.load_dataset). - resplitter : Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] + preprocessor : Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] `Callable` that transforms `DatasetDict` splits, or configuration dict for `MergeResplitter`. partitioners : Dict[str, Union[Partitioner, int]] @@ -79,7 +79,7 @@ def __init__( *, dataset: str, subset: Optional[str] = None, - resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, + preprocessor: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], shuffle: bool = True, seed: Optional[int] = 42, @@ -87,8 +87,8 @@ def __init__( _check_if_dataset_tested(dataset) self._dataset_name: str = dataset self._subset: Optional[str] = subset - self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed( - resplitter + self._resplitter: Optional[Preprocessor] = _instantiate_resplitter_if_needed( + preprocessor ) self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 5d5179122e3b..f65aa6346f3a 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ 
b/datasets/flwr_datasets/federated_dataset_test.py @@ -170,20 +170,20 @@ def test_resplit_dataset_into_one(self) -> None: fds = FederatedDataset( dataset=self.dataset_name, partitioners={"train": 100}, - resplitter={"full": ("train", self.test_split)}, + preprocessor={"full": ("train", self.test_split)}, ) full = fds.load_split("full") self.assertEqual(dataset_length, len(full)) # pylint: disable=protected-access def test_resplit_dataset_to_change_names(self) -> None: - """Test resplitter to change the names of the partitions.""" + """Test preprocessor to change the names of the partitions.""" if self.test_split is None: return fds = FederatedDataset( dataset=self.dataset_name, partitioners={"new_train": 100}, - resplitter={ + preprocessor={ "new_train": ("train",), "new_" + self.test_split: (self.test_split,), }, @@ -195,7 +195,7 @@ def test_resplit_dataset_to_change_names(self) -> None: ) def test_resplit_dataset_by_callable(self) -> None: - """Test resplitter to change the names of the partitions.""" + """Test preprocessor passed as a callable instead of a merge config.""" if self.test_split is None: return @@ -209,7 +209,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict: ) fds = FederatedDataset( - dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit + dataset=self.dataset_name, partitioners={"train": 100}, preprocessor=resplit ) full = fds.load_split("full") dataset = datasets.load_dataset(self.dataset_name) @@ -298,7 +298,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict: fds = FederatedDataset( dataset="does-not-matter", partitioners={"train": 10}, - resplitter=resplit, + preprocessor=resplit, shuffle=True, ) train = fds.load_split("train") @@ -411,7 +411,7 @@ def test_cannot_use_the_old_split_names(self) -> None: fds = FederatedDataset( dataset="mnist", partitioners={"train": 100}, - resplitter={"full": ("train", "test")}, + preprocessor={"full": ("train", "test")}, ) with self.assertRaises(ValueError): fds.load_partition(0, "train") 
diff --git a/datasets/flwr_datasets/resplitter/resplitter.py b/datasets/flwr_datasets/preprocessor.py similarity index 91% rename from datasets/flwr_datasets/resplitter/resplitter.py rename to datasets/flwr_datasets/preprocessor.py index 206e2e85730c..c137b98eeeee 100644 --- a/datasets/flwr_datasets/resplitter/resplitter.py +++ b/datasets/flwr_datasets/preprocessor.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Resplitter.""" +"""Preprocessor.""" from typing import Callable from datasets import DatasetDict -Resplitter = Callable[[DatasetDict], DatasetDict] +Preprocessor = Callable[[DatasetDict], DatasetDict] diff --git a/datasets/flwr_datasets/resplitter/__init__.py b/datasets/flwr_datasets/resplitter/__init__.py index e0b2dc0dcc1c..7d8e79ff9662 100644 --- a/datasets/flwr_datasets/resplitter/__init__.py +++ b/datasets/flwr_datasets/resplitter/__init__.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Resplitter package.""" +"""Preprocessor package.""" from .merge_resplitter import MergeResplitter -from .resplitter import Resplitter __all__ = [ "MergeResplitter", - "Resplitter", ] diff --git a/datasets/flwr_datasets/resplitter/merge_resplitter_test.py b/datasets/flwr_datasets/resplitter/merge_resplitter_test.py index ebbdfb4022b0..ae9123e5c80c 100644 --- a/datasets/flwr_datasets/resplitter/merge_resplitter_test.py +++ b/datasets/flwr_datasets/resplitter/merge_resplitter_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Resplitter tests.""" +"""Preprocessor tests.""" import unittest @@ -25,7 +25,7 @@ class TestResplitter(unittest.TestCase): - """Resplitter tests.""" + """Preprocessor tests.""" def setUp(self) -> None: """Set up the dataset with 3 splits for tests.""" diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index c6f6900a99cd..c4bdf572a223 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -20,7 +20,7 @@ from datasets import Dataset, DatasetDict, concatenate_datasets from flwr_datasets.partitioner import IidPartitioner, Partitioner -from flwr_datasets.resplitter import Resplitter +from flwr_datasets.preprocessor import Preprocessor from flwr_datasets.resplitter.merge_resplitter import MergeResplitter tested_datasets = [ @@ -76,12 +76,12 @@ def _instantiate_partitioners( def _instantiate_resplitter_if_needed( - resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] -) -> Optional[Resplitter]: - """Instantiate `MergeResplitter` if resplitter is merge_config.""" + resplitter: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] +) -> Optional[Preprocessor]: + """Instantiate `MergeResplitter` if `resplitter` is a merge config dict.""" if resplitter and isinstance(resplitter, Dict): resplitter = MergeResplitter(merge_config=resplitter) - return cast(Optional[Resplitter], resplitter) + return cast(Optional[Preprocessor], resplitter) def _check_if_dataset_tested(dataset: str) -> None: diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py index 2d54c3fd63c7..08dd548a386b 100644 --- a/examples/xgboost-comprehensive/client.py +++ b/examples/xgboost-comprehensive/client.py @@ -32,7 +32,7 @@ fds = FederatedDataset( dataset="jxie/higgs", partitioners={"train": partitioner}, - resplitter=resplit, + preprocessor=resplit, ) # Load the partition for this `partition_id` diff --git 
a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index 939819641438..07dc4bed6db4 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -32,7 +32,7 @@ # Load centralised test set if centralised_eval: fds = FederatedDataset( - dataset="jxie/higgs", partitioners={"train": 20}, resplitter=resplit + dataset="jxie/higgs", partitioners={"train": 20}, preprocessor=resplit ) log(INFO, "Loading centralised test set...") test_set = fds.load_split("test") diff --git a/examples/xgboost-comprehensive/sim.py b/examples/xgboost-comprehensive/sim.py index c9481f1cdd5d..09ebbb81fcb4 100644 --- a/examples/xgboost-comprehensive/sim.py +++ b/examples/xgboost-comprehensive/sim.py @@ -80,7 +80,7 @@ def main(): fds = FederatedDataset( dataset="jxie/higgs", partitioners={"train": partitioner}, - resplitter=resplit, + preprocessor=resplit, ) # Load centralised test set