Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Feb 15, 2024
1 parent b4f5dc5 commit d5c22e5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 38 deletions.
94 changes: 59 additions & 35 deletions datasets/flwr_datasets/resplitter/divide_resplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""DivideResplitter class for Flower Datasets."""
import collections
import warnings
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

import datasets
from datasets import DatasetDict
Expand Down Expand Up @@ -64,10 +64,23 @@ def __init__(
divide_split: Optional[str] = None,
drop_remaining_splits: bool = False,
) -> None:
self._divide_config = divide_config
self._single_split_config: Union[Dict[str, float], Dict[str, int]]
self._multiple_splits_config: Union[
Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]
]

self._config_type = _determine_config_type(divide_config)
if self._config_type == "single-split":
self._single_split_config = cast(
Union[Dict[str, float], Dict[str, int]], divide_config
)
else:
self._multiple_splits_config = cast(
Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]],
divide_config,
)
self._divide_split = divide_split
self._drop_remaining_splits = drop_remaining_splits
self._config_type = self._determine_config_type()
self._check_duplicate_splits_in_config()
self._warn_on_potential_misuse_of_divide_split()

Expand All @@ -82,9 +95,8 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the configuration."""
resplit_dataset = {}
dataset_splits: List[str] = list(dataset.keys())
config_type = self._determine_config_type()
# Change the "single-split" config to look like "multiple-split" config
if config_type == "single-split":
if self._config_type == "single-split":
# First, if the `divide_split` is None determine the split
if self._divide_split is None:
if len(dataset_splits) != 1:
Expand All @@ -95,15 +107,17 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict:
)
else:
self._divide_split = dataset_splits[0]
# assert isinstance(self._divide_config, dict)
self._divide_config = {self._divide_split: self._divide_config}
self._multiple_splits_config = cast(
Union[Dict[str, Dict[str, float]], Dict[str, Dict[str, int]]],
{self._divide_split: self._single_split_config},
)

self._check_size_values(dataset)
# Continue with the resplitting process
# Move the non-split splits if they exist
if self._drop_remaining_splits is False:
if len(dataset_splits) >= 2:
split_splits = set(self._divide_config.keys())
split_splits = set(self._multiple_splits_config.keys())
non_split_splits = list(set(dataset_splits) - split_splits)
for non_split_split in non_split_splits:
resplit_dataset[non_split_split] = dataset[non_split_split]
Expand All @@ -112,7 +126,7 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict:
pass

# Split the splits
for split_from, new_splits_dict in self._divide_config.items():
for split_from, new_splits_dict in self._multiple_splits_config.items():
start_index = 0
end_index = 0
split_data = dataset[split_from]
Expand Down Expand Up @@ -148,12 +162,12 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict:
def _check_duplicate_splits_in_config(self) -> None:
"""Check if the new split names are duplicated in `divide_config`."""
if self._config_type == "single-split":
new_splits = list(self._divide_config.keys())
new_splits = list(self._single_split_config.keys())
elif self._config_type == "multiple-splits":
new_splits = []
for new_splits_dict in self._divide_config.values():
for new_splits_dict in self._multiple_splits_config.values():
assert isinstance(new_splits_dict, dict)
new_values = list(new_splits_dict.values())
new_values = list(new_splits_dict.keys())
assert isinstance(new_values, list)
new_splits.extend(new_values)
else:
Expand All @@ -178,15 +192,16 @@ def _check_duplicate_splits_in_config_and_original_dataset(
access to the dataset prior to that).
"""
if self._config_type == "single-split":
new_splits = list(self._divide_config.keys())
new_splits = list(self._single_split_config.keys())
all_splits = dataset_splits + new_splits
assert self._divide_split is not None
all_splits.pop(all_splits.index(self._divide_split))
elif self._config_type == "multiple-splits":
new_splits = []
for new_splits_dict in self._divide_config.values():
for new_splits_dict in self._multiple_splits_config.values():
new_splits.extend(list(new_splits_dict.keys()))
all_splits = dataset_splits + new_splits
for used_split in self._divide_config.keys():
for used_split in self._multiple_splits_config.keys():
all_splits.pop(all_splits.index(used_split))
else:
raise ValueError("Incorrect type of config.")
Expand All @@ -202,28 +217,10 @@ def _check_duplicate_splits_in_config_and_original_dataset(
"Please specify unique values for each new split."
)

def _determine_config_type(self) -> str:
    """Determine configuration type of `divide_config` based on the dict structure.

    Two configurations are possible: 1) single-split single-level (works
    together with `divide_split`), 2) nested/two-level that works with multiple
    splits (`divide_split` is ignored).

    The decision looks only at the values of `self._divide_config`: if at least
    one value is itself a dictionary the config is treated as two-level,
    otherwise (including an empty config) it is treated as single-level.

    Returns
    -------
    config_type: str
        "single-split" or "multiple-splits"
    """
    # One nested dictionary is enough to classify the whole config as two-level.
    if any(isinstance(value, dict) for value in self._divide_config.values()):
        return "multiple-splits"
    # No dictionary values were found, so the config is single-level.
    return "single-split"

def _check_size_values(self, dataset: DatasetDict) -> None:
# It should be called after the `divide_config` is in the multiple-splits format
for split_from, new_split_dict in self._divide_config.items():
assert self._multiple_splits_config is not None
for split_from, new_split_dict in self._multiple_splits_config.items():
if all(isinstance(x, float) for x in new_split_dict.values()):
if not all(0 < x <= 1 for x in new_split_dict.values()):
raise ValueError(
Expand Down Expand Up @@ -260,3 +257,30 @@ def _warn_on_potential_misuse_of_divide_split(self) -> None:
"ignored.",
stacklevel=1,
)


def _determine_config_type(
    config: Union[
        Dict[str, float],
        Dict[str, int],
        Dict[str, Dict[str, float]],
        Dict[str, Dict[str, int]],
    ],
) -> str:
    """Determine configuration type of `divide_config` based on the dict structure.

    Two configurations are possible: 1) single-split single-level (works
    together with `divide_split`), 2) nested/two-level that works with multiple
    splits (`divide_split` is ignored).

    Parameters
    ----------
    config: dict
        The division configuration to classify. Only its values are inspected.

    Returns
    -------
    config_type: str
        "single-split" or "multiple-splits". A config with at least one
        dictionary value is "multiple-splits"; any other config, including an
        empty one, is "single-split".
    """
    # One nested dictionary is enough to classify the whole config as two-level.
    if any(isinstance(value, dict) for value in config.values()):
        return "multiple-splits"
    # No dictionary values were found, so the config is single-level.
    return "single-split"
4 changes: 1 addition & 3 deletions datasets/flwr_datasets/resplitter/divide_resplitter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from flwr_datasets.resplitter import DivideResplitter


# todo a separate case when the name is inferred
@parameterized_class(
("divide_config", "divide_split", "drop_remaining_splits", "split_name_to_size"),
[
Expand Down Expand Up @@ -109,7 +108,6 @@ def test_resplitting_correct_new_split_sizes(self) -> None:
self.assertEqual(self.split_name_to_size, split_to_size)


# todo: test mixed dict types
class TestDivideResplitterIncorrectUseCases(unittest.TestCase):
"""Resplitter tests."""

Expand All @@ -125,7 +123,7 @@ def setUp(self) -> None:

def test_doubling_names_in_config(self) -> None:
"""Test if resplitting raises when the same name in config is detected."""
divide_config = {"train": {"new_train": 0.5}, "valid": {"new_train": 0.5}}
divide_config = {"train": {"new_train": 0.5}, "valid": {"new_train": 0.3}}
divide_split = None
drop_remaining_splits = False

Expand Down

0 comments on commit d5c22e5

Please sign in to comment.