Skip to content

Commit

Permalink
Add divide resplitter
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Feb 15, 2024
1 parent e0dbe0d commit b4f5dc5
Show file tree
Hide file tree
Showing 3 changed files with 494 additions and 0 deletions.
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/resplitter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
"""Resplitter package."""


from .divide_resplitter import DivideResplitter
from .merge_resplitter import MergeResplitter
from .resplitter import Resplitter

__all__ = [
"DivideResplitter",
"MergeResplitter",
"Resplitter",
]
262 changes: 262 additions & 0 deletions datasets/flwr_datasets/resplitter/divide_resplitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DivideResplitter class for Flower Datasets."""
import collections
import warnings
from typing import Dict, List, Optional, Union

import datasets
from datasets import DatasetDict


class DivideResplitter:
"""Dive existing split(s) of the dataset and assign them custom names.
Create new `DatasetDict` with new split names corresponding the percentages of data
and custom names.
Parameters
----------
divide_config: Union[Dict[str, int], Dict[str, float], Dict[str, Dict[str,
int]], Dict[str, Dict[str, float]]]
If single level dictionary with keys - the new split names and the values of int
= number of samples, float - fraction of the split. The fractions do not have
to sum up to 1.0. The order of matter = the first key will get fraction_1
starting from the beginning of the dataset.
If two level dictionary (dictionary of dictionaries) then the first keys are
the split names that will be divided into different splits. It's an alternative
to specifying `divide_split` if you need to divide many splits.
divide_split: Optional[str]
In case of single level dictionary specification of `divide_config`, specifies
the split name that will be divided. Might be left None in case of a single-
split dataset (it will be automatically inferred). Ignored in case of
multi-split configuration.
drop_remaining_splits: bool
In case of single level dictionary specification of `divide_config`, specifies
if the splits that are not divided are dropped.
Raises
------
ValuesError if the specified name of a new split is already present in the dataset
and the `drop_remaining_splits` is False.
"""

def __init__(
self,
divide_config: Union[
Dict[str, float],
Dict[str, int],
Dict[str, Dict[str, float]],
Dict[str, Dict[str, int]],
],
divide_split: Optional[str] = None,
drop_remaining_splits: bool = False,
) -> None:
self._divide_config = divide_config
self._divide_split = divide_split
self._drop_remaining_splits = drop_remaining_splits
self._config_type = self._determine_config_type()
self._check_duplicate_splits_in_config()
self._warn_on_potential_misuse_of_divide_split()

def __call__(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the configuration."""
if self._drop_remaining_splits is False:
dataset_splits = list(dataset.keys())
self._check_duplicate_splits_in_config_and_original_dataset(dataset_splits)
return self.resplit(dataset)

def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the configuration."""
resplit_dataset = {}
dataset_splits: List[str] = list(dataset.keys())
config_type = self._determine_config_type()
# Change the "single-split" config to look like "multiple-split" config
if config_type == "single-split":
# First, if the `divide_split` is None determine the split
if self._divide_split is None:
if len(dataset_splits) != 1:
raise ValueError(
"When giving the config that is single level and working with "
"dataset with more than one split you need to specify the "
"`divide_split` but given None instead."
)
else:
self._divide_split = dataset_splits[0]
# assert isinstance(self._divide_config, dict)
self._divide_config = {self._divide_split: self._divide_config}

self._check_size_values(dataset)
# Continue with the resplitting process
# Move the non-split splits if they exist
if self._drop_remaining_splits is False:
if len(dataset_splits) >= 2:
split_splits = set(self._divide_config.keys())
non_split_splits = list(set(dataset_splits) - split_splits)
for non_split_split in non_split_splits:
resplit_dataset[non_split_split] = dataset[non_split_split]
else:
# The remaining data is not kept (by simply not coping it=the reference)
pass

# Split the splits
for split_from, new_splits_dict in self._divide_config.items():
start_index = 0
end_index = 0
split_data = dataset[split_from]
assert isinstance(new_splits_dict, dict)
for new_split_name, size in new_splits_dict.items():
if isinstance(size, float):
end_index += int(len(split_data) * size)
elif isinstance(size, int):
end_index += size
else:
raise ValueError(
"The type of size value for the divide config must "
"be int or float."
)
if end_index > len(split_data):
raise ValueError(
"The size specified in the `divide_config` is greater than "
"the size of the dataset."
)
if end_index == start_index:
raise ValueError(
f"The size specified in the `divide_config` results in the "
f"dataset of size 0. The problem occurred in {new_splits_dict}."
f"Please make sure to provide sizes that do not produce empty"
f"datasets."
)
resplit_dataset[new_split_name] = split_data.select(
range(start_index, end_index)
)
start_index = end_index
return datasets.DatasetDict(resplit_dataset)

def _check_duplicate_splits_in_config(self) -> None:
"""Check if the new split names are duplicated in `divide_config`."""
if self._config_type == "single-split":
new_splits = list(self._divide_config.keys())
elif self._config_type == "multiple-splits":
new_splits = []
for new_splits_dict in self._divide_config.values():
assert isinstance(new_splits_dict, dict)
new_values = list(new_splits_dict.values())
assert isinstance(new_values, list)
new_splits.extend(new_values)
else:
raise ValueError("Incorrect type of config.")

duplicates = [
item for item, count in collections.Counter(new_splits).items() if count > 1
]
if duplicates:
raise ValueError(
"The specified values of the new splits in "
"`divide_config` are duplicated. Please specify"
"unique values for each new split."
)

def _check_duplicate_splits_in_config_and_original_dataset(
self, dataset_splits: List[str]
) -> None:
"""Check duplicates along the new split values and dataset splits.
This check can happen only at the time this class is called (it does not have
access to the dataset prior to that).
"""
if self._config_type == "single-split":
new_splits = list(self._divide_config.keys())
all_splits = dataset_splits + new_splits
all_splits.pop(all_splits.index(self._divide_split))
elif self._config_type == "multiple-splits":
new_splits = []
for new_splits_dict in self._divide_config.values():
new_splits.extend(list(new_splits_dict.keys()))
all_splits = dataset_splits + new_splits
for used_split in self._divide_config.keys():
all_splits.pop(all_splits.index(used_split))
else:
raise ValueError("Incorrect type of config.")

duplicates = [
item for item, count in collections.Counter(all_splits).items() if count > 1
]
if duplicates:
raise ValueError(
"The specified values of the new splits in "
"`divide_config` are duplicated with the split names of "
"the datasets. "
"Please specify unique values for each new split."
)

def _determine_config_type(self) -> str:
"""Determine configuration type of `divide_config` based on the dict structure.
Two possible configuration are possible: 1) single-split single-level (works
together with `divide_split`), 2) nested/two-level that works with multiple
splits (`divide_split` is ignored).
Returns
-------
config_type: str
"single-split" or "multiple-splits"
"""
for value in self._divide_config.values():
# Check if the value is a dictionary
if isinstance(value, dict):
return "multiple-splits"
# If no dictionary values are found, it is single-level
return "single-split"

def _check_size_values(self, dataset: DatasetDict) -> None:
# It should be called after the `divide_config` is in the multiple-splits format
for split_from, new_split_dict in self._divide_config.items():
if all(isinstance(x, float) for x in new_split_dict.values()):
if not all(0 < x <= 1 for x in new_split_dict.values()):
raise ValueError(
"All fractions in `divide_config` must be greater than 0 and "
"smaller or equal to 1."
)
if sum(new_split_dict.values()) > 1.0:
raise ValueError(
"The sum of the fractions in `divide_config` must be smaller "
"than 1.0."
)

elif all(isinstance(x, int) for x in new_split_dict.values()):
dataset_len = len(dataset[split_from])
len_from_divide_resplit = sum(new_split_dict.values())
if len_from_divide_resplit > dataset_len:
raise ValueError(
f"The sum of the sample numbers in `divide_config` must be "
f"smaller than the split size. This is not the case for "
f"{split_from} split which is of length {dataset_len} and the "
f"sum in config is {len_from_divide_resplit}."
)
else:
raise TypeError(
"The values in `divide_config` must be either ints or floats. "
"The mix of them or other types are not allowed."
)

def _warn_on_potential_misuse_of_divide_split(self) -> None:
if self._config_type == "multiple-splits" and self._divide_split is not None:
warnings.warn(
"The `divide_split` was specified but the multiple split "
"configuration was given. The `divide_split` will be "
"ignored.",
stacklevel=1,
)
Loading

0 comments on commit b4f5dc5

Please sign in to comment.