-
Notifications
You must be signed in to change notification settings - Fork 941
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e0dbe0d
commit b4f5dc5
Showing
3 changed files
with
494 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""DivideResplitter class for Flower Datasets.""" | ||
import collections | ||
import warnings | ||
from typing import Dict, List, Optional, Union | ||
|
||
import datasets | ||
from datasets import DatasetDict | ||
|
||
|
||
class DivideResplitter:
    """Divide existing split(s) of the dataset and assign them custom names.

    Create a new `DatasetDict` whose splits are consecutive, non-overlapping slices
    of the original split(s), named and sized according to `divide_config`.

    Parameters
    ----------
    divide_config: Union[Dict[str, int], Dict[str, float], Dict[str, Dict[str, int]],
        Dict[str, Dict[str, float]]]
        If a single-level dictionary: the keys are the new split names and the
        values are either ints (number of samples) or floats (fraction of the
        divided split). The fractions do not have to sum up to 1.0. The order
        matters: the first key gets its share starting from the beginning of the
        divided split, the second key continues right after it, and so on.
        If a two-level dictionary (dictionary of dictionaries): the first-level keys
        are the names of the splits to divide and each value is a single-level
        dictionary as described above. It is an alternative to specifying
        `divide_split` when many splits need to be divided.
    divide_split: Optional[str]
        For a single-level `divide_config`, the name of the split to divide. May be
        left as None for a single-split dataset (it is then inferred at call time).
        Ignored (with a warning) for a two-level configuration.
    drop_remaining_splits: bool
        If True, the splits that are not divided are dropped from the result.
        If False, they are carried over unchanged.

    Raises
    ------
    ValueError
        If new split names are duplicated within `divide_config`, if `divide_config`
        mixes single-level and two-level entries, or (at call time, when
        `drop_remaining_splits` is False) if a new split name collides with a split
        of the dataset that is kept.
    """

    def __init__(
        self,
        divide_config: Union[
            Dict[str, float],
            Dict[str, int],
            Dict[str, Dict[str, float]],
            Dict[str, Dict[str, int]],
        ],
        divide_split: Optional[str] = None,
        drop_remaining_splits: bool = False,
    ) -> None:
        self._divide_config = divide_config
        self._divide_split = divide_split
        self._drop_remaining_splits = drop_remaining_splits
        self._config_type = self._determine_config_type()
        self._check_duplicate_splits_in_config()
        self._warn_on_potential_misuse_of_divide_split()

    def __call__(self, dataset: "DatasetDict") -> "DatasetDict":
        """Validate the configuration against `dataset` and resplit it."""
        if not self._drop_remaining_splits:
            # Name clashes are only possible when the undivided splits are kept.
            dataset_splits = list(dataset.keys())
            self._check_duplicate_splits_in_config_and_original_dataset(dataset_splits)
        return self.resplit(dataset)

    def resplit(self, dataset: "DatasetDict") -> "DatasetDict":
        """Resplit the dataset according to the configuration.

        The stored configuration is never modified, so the same instance can be
        applied to several datasets.

        Parameters
        ----------
        dataset: DatasetDict
            Dataset whose split(s) will be divided.

        Returns
        -------
        DatasetDict
            The kept original splits (unless `drop_remaining_splits` is True), in
            their original order, followed by the newly created splits.

        Raises
        ------
        ValueError
            If `divide_split` cannot be inferred, if the requested sizes exceed the
            size of the divided split, or if a new split would be empty.
        TypeError
            If the sizes of one divided split are not all ints or all floats.
        """
        dataset_splits: List[str] = list(dataset.keys())
        nested_config = self._get_nested_config(dataset_splits)
        self._check_size_values(dataset)

        resplit_dataset = {}
        if not self._drop_remaining_splits:
            # Carry over (by reference) the splits that are not being divided.
            for split_name in dataset_splits:
                if split_name not in nested_config:
                    resplit_dataset[split_name] = dataset[split_name]

        for split_from, new_splits_dict in nested_config.items():
            split_data = dataset[split_from]
            start_index = 0
            for new_split_name, size in new_splits_dict.items():
                # `_check_size_values` guarantees every size is an int or a float.
                if isinstance(size, float):
                    # Fractions are truncated towards zero samples.
                    end_index = start_index + int(len(split_data) * size)
                else:
                    end_index = start_index + size
                if end_index > len(split_data):
                    raise ValueError(
                        "The size specified in the `divide_config` is greater than "
                        "the size of the dataset."
                    )
                if end_index == start_index:
                    raise ValueError(
                        f"The size specified in the `divide_config` results in the "
                        f"dataset of size 0. The problem occurred in "
                        f"{new_splits_dict}. Please make sure to provide sizes that "
                        f"do not produce empty datasets."
                    )
                resplit_dataset[new_split_name] = split_data.select(
                    range(start_index, end_index)
                )
                start_index = end_index
        return datasets.DatasetDict(resplit_dataset)

    def _resolve_divide_split(self, dataset_splits: List[str]) -> str:
        """Return the split to divide for a single-level config.

        Uses `divide_split` when given; otherwise the dataset must have exactly one
        split, which is then used.
        """
        if self._divide_split is not None:
            return self._divide_split
        if len(dataset_splits) != 1:
            raise ValueError(
                "When giving the config that is single level and working with "
                "dataset with more than one split you need to specify the "
                "`divide_split` but given None instead."
            )
        return dataset_splits[0]

    def _get_nested_config(
        self, dataset_splits: List[str]
    ) -> Dict[str, Dict[str, Union[int, float]]]:
        """Return `divide_config` in two-level form without mutating the instance."""
        if self._config_type == "single-split":
            split_from = self._resolve_divide_split(dataset_splits)
            return {split_from: dict(self._divide_config)}
        return {
            split_from: dict(new_splits_dict)
            for split_from, new_splits_dict in self._divide_config.items()
        }

    def _check_duplicate_splits_in_config(self) -> None:
        """Check if the new split names are duplicated in `divide_config`."""
        if self._config_type == "single-split":
            # Keys of a single dict are unique by construction.
            new_splits = list(self._divide_config)
        elif self._config_type == "multiple-splits":
            # Collect the new split NAMES (keys), not the sizes (values).
            new_splits = [
                new_split_name
                for new_splits_dict in self._divide_config.values()
                for new_split_name in new_splits_dict
            ]
        else:
            raise ValueError("Incorrect type of config.")

        duplicates = [
            item for item, count in collections.Counter(new_splits).items() if count > 1
        ]
        if duplicates:
            raise ValueError(
                "The specified values of the new splits in "
                "`divide_config` are duplicated. Please specify "
                "unique values for each new split."
            )

    def _check_duplicate_splits_in_config_and_original_dataset(
        self, dataset_splits: List[str]
    ) -> None:
        """Check that new split names do not collide with the kept dataset splits.

        This check can happen only at the time this class is called (it does not
        have access to the dataset prior to that). A new split may reuse the name
        of a split that is being divided, because that split is replaced.

        Raises
        ------
        ValueError
            If a split to divide is missing from the dataset, or if a new split
            name equals the name of a split that is not divided.
        """
        nested_config = self._get_nested_config(dataset_splits)
        missing = [split for split in nested_config if split not in dataset_splits]
        if missing:
            raise ValueError(
                f"The split(s) {missing} specified to be divided are not present "
                f"in the dataset. The available splits are {dataset_splits}."
            )

        kept_splits = [split for split in dataset_splits if split not in nested_config]
        new_splits = [
            new_split_name
            for new_splits_dict in nested_config.values()
            for new_split_name in new_splits_dict
        ]
        duplicates = [
            item
            for item, count in collections.Counter(kept_splits + new_splits).items()
            if count > 1
        ]
        if duplicates:
            raise ValueError(
                "The specified values of the new splits in "
                "`divide_config` are duplicated with the split names of "
                "the datasets. "
                "Please specify unique values for each new split."
            )

    def _determine_config_type(self) -> str:
        """Determine configuration type of `divide_config` based on the dict structure.

        Two configurations are possible: 1) single-split single-level (works
        together with `divide_split`), 2) nested/two-level that works with multiple
        splits (`divide_split` is ignored).

        Returns
        -------
        config_type: str
            "single-split" or "multiple-splits"

        Raises
        ------
        ValueError
            If only some of the values are dictionaries.
        """
        is_nested = [isinstance(value, dict) for value in self._divide_config.values()]
        if any(is_nested):
            if not all(is_nested):
                raise ValueError(
                    "Incorrect type of config. The values of `divide_config` must "
                    "be either all dictionaries or all ints/floats."
                )
            return "multiple-splits"
        # No dictionary values were found, so the config is single-level.
        return "single-split"

    def _check_size_values(self, dataset: "DatasetDict") -> None:
        """Validate the sizes requested for every divided split.

        Raises
        ------
        ValueError
            If fractions are outside (0, 1] or sum to more than 1.0, or if sample
            counts are not positive or sum to more than the split size.
        TypeError
            If the sizes of one divided split are not all ints or all floats.
        """
        nested_config = self._get_nested_config(list(dataset.keys()))
        for split_from, new_split_dict in nested_config.items():
            sizes = list(new_split_dict.values())
            if all(isinstance(x, float) for x in sizes):
                if not all(0 < x <= 1 for x in sizes):
                    raise ValueError(
                        "All fractions in `divide_config` must be greater than 0 and "
                        "smaller or equal to 1."
                    )
                if sum(sizes) > 1.0:
                    raise ValueError(
                        "The sum of the fractions in `divide_config` must be smaller "
                        "or equal to 1.0."
                    )
            elif all(isinstance(x, int) for x in sizes):
                if not all(x > 0 for x in sizes):
                    raise ValueError(
                        "All sample numbers in `divide_config` must be greater "
                        "than 0."
                    )
                dataset_len = len(dataset[split_from])
                len_from_divide_resplit = sum(sizes)
                if len_from_divide_resplit > dataset_len:
                    raise ValueError(
                        f"The sum of the sample numbers in `divide_config` must be "
                        f"smaller or equal to the split size. This is not the case "
                        f"for {split_from} split which is of length {dataset_len} "
                        f"and the sum in config is {len_from_divide_resplit}."
                    )
            else:
                raise TypeError(
                    "The values in `divide_config` must be either ints or floats. "
                    "The mix of them or other types are not allowed."
                )

    def _warn_on_potential_misuse_of_divide_split(self) -> None:
        """Warn when `divide_split` is given together with a two-level config."""
        if self._config_type == "multiple-splits" and self._divide_split is not None:
            warnings.warn(
                "The `divide_split` was specified but the multiple split "
                "configuration was given. The `divide_split` will be "
                "ignored.",
                stacklevel=1,
            )
Oops, something went wrong.