class DivideResplitter:
    """Divide existing split(s) of the dataset and assign them custom names.

    Create a new `DatasetDict` whose splits are consecutive, non-overlapping
    slices of the divided split(s), stored under the names given in
    `divide_config`.

    Parameters
    ----------
    divide_config: Union[Dict[str, int], Dict[str, float], Dict[str, Dict[str,
    int]], Dict[str, Dict[str, float]]]
        If a single-level dictionary: the keys are the new split names and the
        values are either ints (number of samples) or floats (fraction of the
        divided split). The fractions do not have to sum up to 1.0. The order
        matters: the first key gets its share starting from the beginning of the
        split, the next key continues where the previous one ended.
        If a two-level dictionary (dictionary of dictionaries): the first-level
        keys are the names of the splits to divide and each value is a
        single-level dictionary as described above. It is an alternative to
        specifying `divide_split` when many splits need to be divided.
    divide_split: Optional[str]
        For a single-level `divide_config`, the name of the split that will be
        divided. May be left as None for a single-split dataset (it is inferred
        when the resplitter is applied). Ignored (with a warning) for a
        two-level configuration.
    drop_remaining_splits: bool
        Whether the splits that are not divided are dropped from the result.

    Raises
    ------
    ValueError
        If a new split name is used more than once in `divide_config`, if the
        configuration mixes nested and flat values, or (when the object is
        called with `drop_remaining_splits=False`) if a new split name collides
        with a split of the dataset that is kept.
    """

    def __init__(
        self,
        divide_config: Union[
            Dict[str, float],
            Dict[str, int],
            Dict[str, Dict[str, float]],
            Dict[str, Dict[str, int]],
        ],
        divide_split: Optional[str] = None,
        drop_remaining_splits: bool = False,
    ) -> None:
        self._divide_config = divide_config
        self._divide_split = divide_split
        self._drop_remaining_splits = drop_remaining_splits
        self._config_type = self._determine_config_type()
        self._check_duplicate_splits_in_config()
        self._warn_on_potential_misuse_of_divide_split()

    def __call__(self, dataset: DatasetDict) -> DatasetDict:
        """Validate split names against the dataset and resplit it."""
        # Name collisions only matter when the untouched splits are kept.
        if not self._drop_remaining_splits:
            self._check_duplicate_splits_in_config_and_original_dataset(
                list(dataset.keys())
            )
        return self.resplit(dataset)

    def resplit(self, dataset: DatasetDict) -> DatasetDict:
        """Resplit the dataset according to the configuration.

        The instance is not modified, so the same resplitter can be applied to
        several datasets.

        Raises
        ------
        ValueError
            If `divide_split` cannot be inferred, if the sizes are invalid, or
            if a requested size would produce an empty split.
        TypeError
            If the sizes of one divided split are not all ints or all floats.
        """
        dataset_splits: List[str] = list(dataset.keys())
        divide_config = self._resolve_divide_config(dataset_splits)
        self._check_size_values(dataset)

        resplit_dataset = {}
        # Keep the splits that are not divided (only references are copied).
        if not self._drop_remaining_splits:
            for split_name in dataset_splits:
                if split_name not in divide_config:
                    resplit_dataset[split_name] = dataset[split_name]

        for split_from, new_splits_dict in divide_config.items():
            split_data = dataset[split_from]
            split_len = len(split_data)
            start_index = 0
            for new_split_name, size in new_splits_dict.items():
                if isinstance(size, float):
                    end_index = start_index + int(split_len * size)
                elif isinstance(size, int):
                    end_index = start_index + size
                else:
                    raise ValueError(
                        "The type of size value for the divide config must "
                        "be int or float."
                    )
                if end_index > split_len:
                    raise ValueError(
                        "The size specified in the `divide_config` is greater than "
                        "the size of the dataset."
                    )
                if end_index == start_index:
                    raise ValueError(
                        f"The size specified in the `divide_config` results in the "
                        f"dataset of size 0. The problem occurred in "
                        f"{new_splits_dict}. Please make sure to provide sizes "
                        f"that do not produce empty datasets."
                    )
                resplit_dataset[new_split_name] = split_data.select(
                    range(start_index, end_index)
                )
                start_index = end_index
        return datasets.DatasetDict(resplit_dataset)

    def _resolve_divide_config(
        self, dataset_splits: List[str]
    ) -> Dict[str, Dict[str, Union[int, float]]]:
        """Return the configuration in two-level form without mutating `self`.

        A single-level config is wrapped as `{divide_split: divide_config}`; when
        `divide_split` is None it is inferred from a single-split dataset.
        """
        if self._config_type != "single-split":
            return self._divide_config  # type: ignore[return-value]
        divide_split = self._divide_split
        if divide_split is None:
            if len(dataset_splits) != 1:
                raise ValueError(
                    "When giving the config that is single level and working with "
                    "dataset with more than one split you need to specify the "
                    "`divide_split` but given None instead."
                )
            divide_split = dataset_splits[0]
        return {divide_split: self._divide_config}  # type: ignore[dict-item]

    def _check_duplicate_splits_in_config(self) -> None:
        """Check if the new split names are duplicated in `divide_config`."""
        if self._config_type == "single-split":
            # Dictionary keys are unique, kept for symmetry with the nested case.
            new_splits = list(self._divide_config)
        elif self._config_type == "multiple-splits":
            # Compare the new split NAMES (keys); equal sizes are legitimate.
            new_splits = [
                new_name
                for new_splits_dict in self._divide_config.values()
                for new_name in new_splits_dict  # type: ignore[union-attr]
            ]
        else:
            raise ValueError("Incorrect type of config.")

        duplicates = [
            item for item, count in collections.Counter(new_splits).items() if count > 1
        ]
        if duplicates:
            raise ValueError(
                "The specified values of the new splits in "
                "`divide_config` are duplicated. Please specify "
                f"unique values for each new split. Duplicated: {duplicates}."
            )

    def _check_duplicate_splits_in_config_and_original_dataset(
        self, dataset_splits: List[str]
    ) -> None:
        """Check that new split names do not collide with the kept dataset splits.

        This check can happen only at the time this class is called (it does not
        have access to the dataset prior to that). A new split may reuse the name
        of a split that is being divided, because that split is not kept.
        """
        divide_config = self._resolve_divide_config(dataset_splits)
        missing = [name for name in divide_config if name not in dataset_splits]
        if missing:
            raise ValueError(
                f"The split(s) {missing} requested to be divided are not present "
                f"in the dataset. Available splits: {dataset_splits}."
            )
        kept_splits = [name for name in dataset_splits if name not in divide_config]
        new_splits = [
            new_name
            for new_splits_dict in divide_config.values()
            for new_name in new_splits_dict
        ]
        duplicates = [
            item
            for item, count in collections.Counter(kept_splits + new_splits).items()
            if count > 1
        ]
        if duplicates:
            raise ValueError(
                "The specified values of the new splits in "
                "`divide_config` are duplicated with the split names of "
                "the datasets. "
                "Please specify unique values for each new split."
            )

    def _determine_config_type(self) -> str:
        """Determine configuration type of `divide_config` based on the dict structure.

        Two configurations are possible: 1) single-split single-level (works
        together with `divide_split`), 2) nested/two-level that works with
        multiple splits (`divide_split` is ignored).

        Returns
        -------
        config_type: str
            "single-split" or "multiple-splits"

        Raises
        ------
        ValueError
            If only some of the values are dictionaries.
        """
        is_nested = [isinstance(value, dict) for value in self._divide_config.values()]
        if any(is_nested):
            if not all(is_nested):
                raise ValueError(
                    "The `divide_config` mixes nested dictionaries with plain size "
                    "values. Use either a single-level or a two-level dictionary."
                )
            return "multiple-splits"
        # If no dictionary values are found, it is single-level
        return "single-split"

    def _check_size_values(self, dataset: DatasetDict) -> None:
        """Validate the fractions / sample counts requested for each divided split.

        Raises
        ------
        ValueError
            If a fraction is outside (0, 1], fractions sum to more than 1.0, a
            sample count is not positive, or counts exceed the split length.
        TypeError
            If the sizes of one divided split are not all floats or all ints.
        """
        divide_config = self._resolve_divide_config(list(dataset.keys()))
        for split_from, new_split_dict in divide_config.items():
            sizes = list(new_split_dict.values())
            if all(isinstance(x, float) for x in sizes):
                if not all(0 < x <= 1 for x in sizes):
                    raise ValueError(
                        "All fractions in `divide_config` must be greater than 0 and "
                        "smaller or equal to 1."
                    )
                if sum(sizes) > 1.0:
                    raise ValueError(
                        "The sum of the fractions in `divide_config` must not "
                        "exceed 1.0."
                    )
            elif all(isinstance(x, int) for x in sizes):
                # A non-positive count would yield an empty (or nonsensical) slice.
                if not all(x > 0 for x in sizes):
                    raise ValueError(
                        "All sample numbers in `divide_config` must be greater "
                        "than 0."
                    )
                dataset_len = len(dataset[split_from])
                len_from_divide_resplit = sum(sizes)
                if len_from_divide_resplit > dataset_len:
                    raise ValueError(
                        f"The sum of the sample numbers in `divide_config` must "
                        f"not exceed the split size. This is not the case for "
                        f"{split_from} split which is of length {dataset_len} and "
                        f"the sum in config is {len_from_divide_resplit}."
                    )
            else:
                raise TypeError(
                    "The values in `divide_config` must be either ints or floats. "
                    "The mix of them or other types are not allowed."
                )

    def _warn_on_potential_misuse_of_divide_split(self) -> None:
        """Warn when `divide_split` is given together with a two-level config."""
        if self._config_type == "multiple-splits" and self._divide_split is not None:
            warnings.warn(
                "The `divide_split` was specified but the multiple split "
                "configuration was given. The `divide_split` will be "
                "ignored.",
                # Point at the code constructing the resplitter, not this helper.
                stacklevel=3,
            )
# Any of the configuration shapes accepted by `DivideResplitter`.
DivideConfig = Union[
    Dict[str, float],
    Dict[str, int],
    Dict[str, Dict[str, float]],
    Dict[str, Dict[str, int]],
]


def _make_dataset_dict() -> DatasetDict:
    """Build the 3-split fixture: train (40 rows), valid (20 rows), test (40 rows)."""
    return DatasetDict(
        {
            "train": Dataset.from_dict({"data": list(range(40))}),
            "valid": Dataset.from_dict({"data": list(range(40, 60))}),
            "test": Dataset.from_dict({"data": list(range(60, 100))}),
        }
    )


# TODO: add a separate case for when the name of the divided split is inferred.
@parameterized_class(
    ("divide_config", "divide_split", "drop_remaining_splits", "split_name_to_size"),
    [
        # Standard config that sums to one
        (
            {"train_1": 0.25, "train_2": 0.75},
            "train",
            False,
            {"train_1": 10, "train_2": 30, "valid": 20, "test": 40},
        ),
        # As the first use case but drop the remaining splits
        (
            {"train_1": 0.25, "train_2": 0.75},
            "train",
            True,
            {"train_1": 10, "train_2": 30},
        ),
        # Split does not sum to 1.0
        (
            {"a": 0.2, "b": 0.4},
            "valid",
            False,
            {"a": 4, "b": 8, "train": 40, "test": 40},
        ),
        # Completely custom names
        (
            {"test_a": 0.2, "asdfasdfsa": 0.4},
            "test",
            False,
            {"test_a": 8, "asdfasdfsa": 16, "valid": 20, "train": 40},
        ),
        # Mirror copy of the first example but using the multiple-split config
        (
            {"train": {"train_1": 0.25, "train_2": 0.75}},
            None,
            False,
            {"train_1": 10, "train_2": 30, "valid": 20, "test": 40},
        ),
    ],
)
class TestDivideResplitter(unittest.TestCase):
    """DivideResplitter tests for valid configurations."""

    divide_config: DivideConfig
    # None is used by the multiple-split (two-level) configuration case.
    divide_split: Union[str, None]
    drop_remaining_splits: bool
    split_name_to_size: Dict[str, int]

    def setUp(self) -> None:
        """Set up the dataset with 3 splits for tests."""
        self.dataset_dict = _make_dataset_dict()

    def test_resplitting_correct_new_split_names(self) -> None:
        """Test if resplitting produces requested new splits."""
        resplitter = DivideResplitter(
            self.divide_config, self.divide_split, self.drop_remaining_splits
        )
        resplit_dataset = resplitter(self.dataset_dict)
        new_keys = set(resplit_dataset.keys())
        self.assertEqual(set(self.split_name_to_size.keys()), new_keys)

    def test_resplitting_correct_new_split_sizes(self) -> None:
        """Test if resplitting produces correct sizes of splits."""
        resplitter = DivideResplitter(
            self.divide_config, self.divide_split, self.drop_remaining_splits
        )
        resplit_dataset = resplitter(self.dataset_dict)
        split_to_size = {
            split_name: len(split) for split_name, split in resplit_dataset.items()
        }
        self.assertEqual(self.split_name_to_size, split_to_size)


# TODO: test mixed dict types
class TestDivideResplitterIncorrectUseCases(unittest.TestCase):
    """DivideResplitter tests for configurations that must be rejected."""

    def setUp(self) -> None:
        """Set up the dataset with 3 splits for tests."""
        self.dataset_dict = _make_dataset_dict()

    def _assert_call_raises(
        self, divide_config: DivideConfig, divide_split: Union[str, None]
    ) -> None:
        """Build a resplitter (keeping remaining splits) and expect the call to fail.

        The resplitter is constructed outside of `assertRaises` on purpose: the
        error is expected when the dataset is resplit, not at construction.
        """
        resplitter = DivideResplitter(divide_config, divide_split, False)
        with self.assertRaises(ValueError):
            _ = resplitter(self.dataset_dict)

    def test_doubling_names_in_config(self) -> None:
        """Test if resplitting raises when the same name in config is detected."""
        divide_config = {"train": {"new_train": 0.5}, "valid": {"new_train": 0.5}}
        divide_split = None
        drop_remaining_splits = False

        # The error may come from the constructor or from the call.
        with self.assertRaises(ValueError):
            resplitter = DivideResplitter(
                divide_config, divide_split, drop_remaining_splits
            )
            _ = resplitter(self.dataset_dict)

    def test_duplicate_names_in_config_and_dataset_split_names_multisplit(self) -> None:
        """Test if resplitting raises when the name collides with the old name."""
        self._assert_call_raises({"train": {"valid": 0.5}}, None)

    def test_duplicate_names_in_config_and_dataset_split_names_single_split(
        self,
    ) -> None:
        """Test if resplitting raises when the name collides with the old name."""
        self._assert_call_raises({"valid": 0.5}, "train")

    def test_fraction_sum_up_to_more_than_one_multisplit(self) -> None:
        """Test if resplitting raises when fractions sum up to > 1.0 ."""
        self._assert_call_raises({"train": {"train_1": 0.5, "train_2": 0.7}}, None)

    def test_fraction_sum_up_to_more_than_one_single_split(self) -> None:
        """Test if resplitting raises when fractions sum up to > 1.0 ."""
        self._assert_call_raises({"train_1": 0.5, "train_2": 0.7}, "train")

    def test_sample_sizes_sum_up_to_more_than_dataset_size_multisplit(self) -> None:
        """Test if resplitting raises when sample sizes sum up to > len(dataset)."""
        self._assert_call_raises({"train": {"train_1": 20, "train_2": 25}}, None)

    def test_sample_sizes_sum_up_to_more_than_dataset_size_single_split(self) -> None:
        """Test if resplitting raises when sample sizes sum up to > len(dataset)."""
        self._assert_call_raises({"train_1": 20, "train_2": 25}, "train")

    def test_too_small_size_values_create_empty_dataset_multisplit(self) -> None:
        """Test if resplitting raises when fraction creates empty dataset."""
        self._assert_call_raises({"train": {"train_1": 0.2, "train_2": 0.0001}}, None)

    def test_too_small_size_values_create_empty_dataset_single_split(self) -> None:
        """Test if resplitting raises when fraction creates empty dataset."""
        self._assert_call_raises({"train_1": 0.2, "train_2": 0.0001}, "train")


if __name__ == "__main__":
    unittest.main()