From d5c22e5336cccc075c42bd5862e3aaeb723e72ea Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 15 Feb 2024 11:02:10 +0100 Subject: [PATCH] Fix mypy errors --- .../resplitter/divide_resplitter.py | 94 ++++++++++++------- .../resplitter/divide_resplitter_test.py | 4 +- 2 files changed, 60 insertions(+), 38 deletions(-) diff --git a/datasets/flwr_datasets/resplitter/divide_resplitter.py b/datasets/flwr_datasets/resplitter/divide_resplitter.py index 4d833648babb..d0d9d8353f83 100644 --- a/datasets/flwr_datasets/resplitter/divide_resplitter.py +++ b/datasets/flwr_datasets/resplitter/divide_resplitter.py @@ -15,7 +15,7 @@ """DivideResplitter class for Flower Datasets.""" import collections import warnings -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, cast import datasets from datasets import DatasetDict @@ -64,10 +64,23 @@ def __init__( divide_split: Optional[str] = None, drop_remaining_splits: bool = False, ) -> None: - self._divide_config = divide_config + self._single_split_config: Union[Dict[str, float], Dict[str, int]] + self._multiple_splits_config: Union[ + Dict[str, Dict[str, float]], Dict[str, Dict[str, int]] + ] + + self._config_type = _determine_config_type(divide_config) + if self._config_type == "single-split": + self._single_split_config = cast( + Union[Dict[str, float], Dict[str, int]], divide_config + ) + else: + self._multiple_splits_config = cast( + Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]], + divide_config, + ) self._divide_split = divide_split self._drop_remaining_splits = drop_remaining_splits - self._config_type = self._determine_config_type() self._check_duplicate_splits_in_config() self._warn_on_potential_misuse_of_divide_split() @@ -82,9 +95,8 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict: """Resplit the dataset according to the configuration.""" resplit_dataset = {} dataset_splits: List[str] = list(dataset.keys()) - config_type = self._determine_config_type() # Change the "single-split" config to look like "multiple-split" config - if config_type == "single-split": + if self._config_type == "single-split": # First, if the `divide_split` is None determine the split if self._divide_split is None: if len(dataset_splits) != 1: @@ -95,15 +107,17 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict: ) else: self._divide_split = dataset_splits[0] - # assert isinstance(self._divide_config, dict) - self._divide_config = {self._divide_split: self._divide_config} + self._multiple_splits_config = cast( + Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]], + {self._divide_split: self._single_split_config}, + ) self._check_size_values(dataset) # Continue with the resplitting process # Move the non-split splits if they exist if self._drop_remaining_splits is False: if len(dataset_splits) >= 2: - split_splits = set(self._divide_config.keys()) + split_splits = set(self._multiple_splits_config.keys()) non_split_splits = list(set(dataset_splits) - split_splits) for non_split_split in non_split_splits: resplit_dataset[non_split_split] = dataset[non_split_split] @@ -112,7 +126,7 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict: pass # Split the splits - for split_from, new_splits_dict in self._divide_config.items(): + for split_from, new_splits_dict in self._multiple_splits_config.items(): start_index = 0 end_index = 0 split_data = dataset[split_from] @@ -148,12 +162,12 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict: def _check_duplicate_splits_in_config(self) -> None: """Check if the new split names are duplicated in `divide_config`.""" if self._config_type == "single-split": - new_splits = list(self._divide_config.keys()) + new_splits = list(self._single_split_config.keys()) elif self._config_type == "multiple-splits": new_splits = [] - for new_splits_dict in self._divide_config.values(): + for new_splits_dict in self._multiple_splits_config.values(): assert isinstance(new_splits_dict, dict) - new_values = list(new_splits_dict.values()) + new_values = list(new_splits_dict.keys()) assert isinstance(new_values, list) new_splits.extend(new_values) else: @@ -178,15 +192,16 @@ def _check_duplicate_splits_in_config_and_original_dataset( access to the dataset prior to that). """ if self._config_type == "single-split": - new_splits = list(self._divide_config.keys()) + new_splits = list(self._single_split_config.keys()) all_splits = dataset_splits + new_splits + assert self._divide_split is not None all_splits.pop(all_splits.index(self._divide_split)) elif self._config_type == "multiple-splits": new_splits = [] - for new_splits_dict in self._divide_config.values(): + for new_splits_dict in self._multiple_splits_config.values(): new_splits.extend(list(new_splits_dict.keys())) all_splits = dataset_splits + new_splits - for used_split in self._divide_config.keys(): + for used_split in self._multiple_splits_config.keys(): all_splits.pop(all_splits.index(used_split)) else: raise ValueError("Incorrect type of config.") @@ -202,28 +217,10 @@ def _check_duplicate_splits_in_config_and_original_dataset( "Please specify unique values for each new split." ) - def _determine_config_type(self) -> str: - """Determine configuration type of `divide_config` based on the dict structure. - - Two possible configuration are possible: 1) single-split single-level (works - together with `divide_split`), 2) nested/two-level that works with multiple - splits (`divide_split` is ignored). - - Returns - ------- - config_type: str - "single-split" or "multiple-splits" - """ - for value in self._divide_config.values(): - # Check if the value is a dictionary - if isinstance(value, dict): - return "multiple-splits" - # If no dictionary values are found, it is single-level - return "single-split" - def _check_size_values(self, dataset: DatasetDict) -> None: # It should be called after the `divide_config` is in the multiple-splits format - for split_from, new_split_dict in self._divide_config.items(): + assert self._multiple_splits_config is not None + for split_from, new_split_dict in self._multiple_splits_config.items(): if all(isinstance(x, float) for x in new_split_dict.values()): if not all(0 < x <= 1 for x in new_split_dict.values()): raise ValueError( @@ -260,3 +257,30 @@ def _warn_on_potential_misuse_of_divide_split(self) -> None: "ignored.", stacklevel=1, ) + + +def _determine_config_type( + config: Union[ + Dict[str, float], + Dict[str, int], + Dict[str, Dict[str, float]], + Dict[str, Dict[str, int]], + ], +) -> str: + """Determine configuration type of `divide_config` based on the dict structure. + + Two possible configuration are possible: 1) single-split single-level (works + together with `divide_split`), 2) nested/two-level that works with multiple + splits (`divide_split` is ignored). + + Returns + ------- + config_type: str + "single-split" or "multiple-splits" + """ + for value in config.values(): + # Check if the value is a dictionary + if isinstance(value, dict): + return "multiple-splits" + # If no dictionary values are found, it is single-level + return "single-split" diff --git a/datasets/flwr_datasets/resplitter/divide_resplitter_test.py b/datasets/flwr_datasets/resplitter/divide_resplitter_test.py index 674d6e66086f..7c8da36d28a8 100644 --- a/datasets/flwr_datasets/resplitter/divide_resplitter_test.py +++ b/datasets/flwr_datasets/resplitter/divide_resplitter_test.py @@ -23,7 +23,6 @@ from flwr_datasets.resplitter import DivideResplitter -# todo a separate case when the name is inferred @parameterized_class( ("divide_config", "divide_split", "drop_remaining_splits", "split_name_to_size"), [ @@ -109,7 +108,6 @@ def test_resplitting_correct_new_split_sizes(self) -> None: self.assertEqual(self.split_name_to_size, split_to_size) -# todo: test mixed dict types class TestDivideResplitterIncorrectUseCases(unittest.TestCase): """Resplitter tests.""" @@ -125,7 +123,7 @@ def setUp(self) -> None: def test_doubling_names_in_config(self) -> None: """Test if resplitting raises when the same name in config is detected.""" - divide_config = {"train": {"new_train": 0.5}, "valid": {"new_train": 0.5}} + divide_config = {"train": {"new_train": 0.5}, "valid": {"new_train": 0.3}} divide_split = None drop_remaining_splits = False