Skip to content

Commit

Permalink
Rename resplitter parameter and type to preprocessor
Browse files (browse the repository at this point in the history)
  • Loading branch information
adam-narozniak committed May 20, 2024
1 parent bbad21f commit 74e70aa
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 27 deletions.
10 changes: 5 additions & 5 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.resplitter import Resplitter
from flwr_datasets.preprocessor import Preprocessor
from flwr_datasets.utils import (
_check_if_dataset_tested,
_instantiate_partitioners,
Expand All @@ -45,7 +45,7 @@ class FederatedDataset:
subset : str
Secondary information regarding the dataset, most often subset or version
(that is passed to the name in datasets.load_dataset).
resplitter : Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
preprocessor : Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]]
`Callable` that transforms `DatasetDict` splits, or configuration dict for
`MergeResplitter`.
partitioners : Dict[str, Union[Partitioner, int]]
Expand Down Expand Up @@ -79,16 +79,16 @@ def __init__(
*,
dataset: str,
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
preprocessor: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._subset: Optional[str] = subset
self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed(
resplitter
self._resplitter: Optional[Preprocessor] = _instantiate_resplitter_if_needed(
preprocessor
)
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
Expand Down
14 changes: 7 additions & 7 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,20 @@ def test_resplit_dataset_into_one(self) -> None:
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"train": 100},
resplitter={"full": ("train", self.test_split)},
preprocessor={"full": ("train", self.test_split)},
)
full = fds.load_split("full")
self.assertEqual(dataset_length, len(full))

# pylint: disable=protected-access
def test_resplit_dataset_to_change_names(self) -> None:
"""Test resplitter to change the names of the partitions."""
"""Test preprocessor to change the names of the partitions."""
if self.test_split is None:
return
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"new_train": 100},
resplitter={
preprocessor={
"new_train": ("train",),
"new_" + self.test_split: (self.test_split,),
},
Expand All @@ -195,7 +195,7 @@ def test_resplit_dataset_to_change_names(self) -> None:
)

def test_resplit_dataset_by_callable(self) -> None:
"""Test resplitter to change the names of the partitions."""
"""Test preprocessor to change the names of the partitions."""
if self.test_split is None:
return

Expand All @@ -209,7 +209,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
)

fds = FederatedDataset(
dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit
dataset=self.dataset_name, partitioners={"train": 100}, preprocessor=resplit
)
full = fds.load_split("full")
dataset = datasets.load_dataset(self.dataset_name)
Expand Down Expand Up @@ -298,7 +298,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
fds = FederatedDataset(
dataset="does-not-matter",
partitioners={"train": 10},
resplitter=resplit,
preprocessor=resplit,
shuffle=True,
)
train = fds.load_split("train")
Expand Down Expand Up @@ -411,7 +411,7 @@ def test_cannot_use_the_old_split_names(self) -> None:
fds = FederatedDataset(
dataset="mnist",
partitioners={"train": 100},
resplitter={"full": ("train", "test")},
preprocessor={"full": ("train", "test")},
)
with self.assertRaises(ValueError):
fds.load_partition(0, "train")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Resplitter."""
"""Preprocessor."""


from typing import Callable

from datasets import DatasetDict

Resplitter = Callable[[DatasetDict], DatasetDict]
Preprocessor = Callable[[DatasetDict], DatasetDict]
4 changes: 1 addition & 3 deletions datasets/flwr_datasets/resplitter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Resplitter package."""
"""Preprocessor package."""


from .merge_resplitter import MergeResplitter
from .resplitter import Resplitter

__all__ = [
"MergeResplitter",
"Resplitter",
]
4 changes: 2 additions & 2 deletions datasets/flwr_datasets/resplitter/merge_resplitter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Resplitter tests."""
"""Preprocessor tests."""


import unittest
Expand All @@ -25,7 +25,7 @@


class TestResplitter(unittest.TestCase):
"""Resplitter tests."""
"""Preprocessor tests."""

def setUp(self) -> None:
"""Set up the dataset with 3 splits for tests."""
Expand Down
10 changes: 5 additions & 5 deletions datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import IidPartitioner, Partitioner
from flwr_datasets.resplitter import Resplitter
from flwr_datasets.preprocessor import Preprocessor
from flwr_datasets.resplitter.merge_resplitter import MergeResplitter

tested_datasets = [
Expand Down Expand Up @@ -76,12 +76,12 @@ def _instantiate_partitioners(


def _instantiate_resplitter_if_needed(
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
) -> Optional[Resplitter]:
"""Instantiate `MergeResplitter` if resplitter is merge_config."""
resplitter: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]]
) -> Optional[Preprocessor]:
"""Instantiate `MergeResplitter` if preprocessor is merge_config."""
if resplitter and isinstance(resplitter, Dict):
resplitter = MergeResplitter(merge_config=resplitter)
return cast(Optional[Resplitter], resplitter)
return cast(Optional[Preprocessor], resplitter)


def _check_if_dataset_tested(dataset: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
fds = FederatedDataset(
dataset="jxie/higgs",
partitioners={"train": partitioner},
resplitter=resplit,
preprocessor=resplit,
)

# Load the partition for this `partition_id`
Expand Down
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# Load centralised test set
if centralised_eval:
fds = FederatedDataset(
dataset="jxie/higgs", partitioners={"train": 20}, resplitter=resplit
dataset="jxie/higgs", partitioners={"train": 20}, preprocessor=resplit
)
log(INFO, "Loading centralised test set...")
test_set = fds.load_split("test")
Expand Down
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def main():
fds = FederatedDataset(
dataset="jxie/higgs",
partitioners={"train": partitioner},
resplitter=resplit,
preprocessor=resplit,
)

# Load centralised test set
Expand Down

0 comments on commit 74e70aa

Please sign in to comment.