diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 7ca2b44570ca..5d5179122e3b 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -20,29 +20,70 @@ from typing import Dict, Union from unittest.mock import Mock, patch +import numpy as np import pytest from parameterized import parameterized, parameterized_class import datasets from datasets import Dataset, DatasetDict, concatenate_datasets from flwr_datasets.federated_dataset import FederatedDataset +from flwr_datasets.mock_utils_test import _load_mocked_dataset from flwr_datasets.partitioner import IidPartitioner, Partitioner +mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"] + @parameterized_class( + ("dataset_name", "test_split", "subset"), [ - {"dataset_name": "mnist", "test_split": "test"}, - {"dataset_name": "cifar10", "test_split": "test"}, - {"dataset_name": "fashion_mnist", "test_split": "test"}, - {"dataset_name": "sasha/dog-food", "test_split": "test"}, - {"dataset_name": "zh-plus/tiny-imagenet", "test_split": "valid"}, - ] + # Downloaded + # #Image datasets + ("mnist", "test", ""), + ("cifar10", "test", ""), + ("fashion_mnist", "test", ""), + ("sasha/dog-food", "test", ""), + ("zh-plus/tiny-imagenet", "valid", ""), + # Text + ("scikit-learn/adult-census-income", None, ""), + # Mocked + # #Image + ("cifar100", "test", ""), + # Note: there's also the extra split and full_numbers subset + ("svhn", "test", "cropped_digits"), + # Text + ("sentiment140", "test", ""), # aka twitter + # Audio + ("speech_commands", "test", "v0.01"), + ], ) -class RealDatasetsFederatedDatasetsTrainTest(unittest.TestCase): - """Test Real Dataset (MNIST, CIFAR10) in FederatedDatasets.""" +class BaseFederatedDatasetsTest(unittest.TestCase): + """Test Real/Mocked Datasets used in FederatedDatasets. 
+ + The setUp method mocks the dataset download via datasets.load_dataset if it is in + the `mocked_datasets` list. + """ dataset_name = "" test_split = "" + subset = "" + + def setUp(self) -> None: + """Mock the dataset download prior to each method if needed. + + If the `dataset_name` is in the `mocked_datasets` list, then the dataset + download is mocked. + """ + if self.dataset_name in mocked_datasets: + self.patcher = patch("datasets.load_dataset") + self.mock_load_dataset = self.patcher.start() + self.mock_load_dataset.return_value = _load_mocked_dataset( + self.dataset_name, [200, 100], ["train", self.test_split], self.subset + ) + + def tearDown(self) -> None: + """Clean up after the dataset mocking.""" + if self.dataset_name in mocked_datasets: + patch.stopall() @parameterized.expand( # type: ignore [ @@ -61,14 +102,25 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None: dataset_fds = FederatedDataset( dataset=self.dataset_name, partitioners={"train": train_num_partitions} ) - dataset_partition0 = dataset_fds.load_partition(0, "train") + # Compute the actual partition sizes + partition_sizes = [] + for node_id in range(train_num_partitions): + partition_sizes.append(len(dataset_fds.load_partition(node_id, "train"))) + + # Create the expected sizes of partitions dataset = datasets.load_dataset(self.dataset_name) - self.assertEqual( - len(dataset_partition0), len(dataset["train"]) // train_num_partitions - ) + full_train_length = len(dataset["train"]) + expected_sizes = [] + default_partition_size = full_train_length // train_num_partitions + mod = full_train_length % train_num_partitions + for i in range(train_num_partitions): + expected_sizes.append(default_partition_size + (1 if i < mod else 0)) + self.assertEqual(partition_sizes, expected_sizes) def test_load_split(self) -> None: """Test if the load_split works with the correct split name.""" + if self.test_split is None: + return dataset_fds = FederatedDataset( 
dataset=self.dataset_name, partitioners={"train": 100} ) @@ -78,6 +130,8 @@ def test_load_split(self) -> None: def test_multiple_partitioners(self) -> None: """Test if the dataset works when multiple partitioners are specified.""" + if self.test_split is None: + return num_train_partitions = 100 num_test_partitions = 100 dataset_fds = FederatedDataset( @@ -97,7 +151,7 @@ def test_multiple_partitioners(self) -> None: def test_no_need_for_split_keyword_if_one_partitioner(self) -> None: """Test if partitions got with and without split args are the same.""" - fds = FederatedDataset(dataset="mnist", partitioners={"train": 10}) + fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10}) partition_loaded_with_no_split_arg = fds.load_partition(0) partition_loaded_with_verbose_split_arg = fds.load_partition(0, "train") self.assertTrue( @@ -109,6 +163,8 @@ def test_no_need_for_split_keyword_if_one_partitioner(self) -> None: def test_resplit_dataset_into_one(self) -> None: """Test resplit into a single dataset.""" + if self.test_split is None: + return dataset = datasets.load_dataset(self.dataset_name) dataset_length = sum([len(ds) for ds in dataset.values()]) fds = FederatedDataset( @@ -122,6 +178,8 @@ def test_resplit_dataset_into_one(self) -> None: # pylint: disable=protected-access def test_resplit_dataset_to_change_names(self) -> None: """Test resplitter to change the names of the partitions.""" + if self.test_split is None: + return fds = FederatedDataset( dataset=self.dataset_name, partitioners={"new_train": 100}, @@ -138,6 +196,8 @@ def test_resplit_dataset_to_change_names(self) -> None: def test_resplit_dataset_by_callable(self) -> None: """Test resplitter to change the names of the partitions.""" + if self.test_split is None: + return def resplit(dataset: DatasetDict) -> DatasetDict: return DatasetDict( @@ -157,8 +217,13 @@ def resplit(dataset: DatasetDict) -> DatasetDict: self.assertEqual(len(full), dataset_length) -class 
ArtificialDatasetTest(unittest.TestCase): - """Test using small artificial dataset, mocked load_dataset.""" +class ShufflingResplittingOnArtificialDatasetTest(unittest.TestCase): + """Test shuffling and resplitting using small artificial dataset. + + The purpose of this class is to ensure the order of samples remains as expected. + + The load_dataset method is mocked and the artificial dataset is returned. + """ # pylint: disable=no-self-use def _dummy_setup(self, train_rows: int = 10, test_rows: int = 5) -> DatasetDict: @@ -360,9 +425,26 @@ def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: # Iterate over each row and check for equality for row1, row2 in zip(ds1, ds2): - if row1 != row2: + # Ensure all keys are the same in both rows + if set(row1.keys()) != set(row2.keys()): return False + # Compare values for each key + for key in row1: + if key == "audio": + # Special handling for 'audio' key + if not all( + [ + np.array_equal(row1[key]["array"], row2[key]["array"]), + row1[key]["path"] == row2[key]["path"], + row1[key]["sampling_rate"] == row2[key]["sampling_rate"], + ] + ): + return False + elif row1[key] != row2[key]: + # Direct comparison for other keys + return False + return True diff --git a/datasets/flwr_datasets/mock_utils_test.py b/datasets/flwr_datasets/mock_utils_test.py new file mode 100644 index 000000000000..78aff1f1cdd7 --- /dev/null +++ b/datasets/flwr_datasets/mock_utils_test.py @@ -0,0 +1,377 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils for mocking datasets.""" + + +import io +import random +import string +from datetime import datetime, timedelta +from typing import Any, Dict, List, Set, Tuple, Union + +import numpy as np +from PIL import Image + +import datasets +from datasets import ClassLabel, Dataset, DatasetDict, Features, Value + + +def _generate_artificial_strings( + num_rows: int, num_unique: int, string_length: int, seed: int = 42 +) -> List[str]: + """Create list of strings for categories or labels mocking. + + Note to keep the seed the same if you reuse this function for the creation of the + dataset for multiple splits. + + Parameters + ---------- + num_rows: int + Number of rows = number of elements in the list. + num_unique: int + Number of unique strings that will be initially created. + string_length: int + Length of each string. + seed: int + Seed for the random package. + + Returns + ------- + string_column : List[str] + List of generated strings.
+ """ + random.seed(seed) + unique_strings: Set[str] = set() + while len(unique_strings) < num_unique: + random_str = "".join( + random.choices(string.ascii_letters + string.digits, k=string_length) + ) + unique_strings.add(random_str) + + unique_strings_list = list(unique_strings) + artificial_column = unique_strings_list.copy() + remaining_to_allocate = num_rows - num_unique + for _ in range(remaining_to_allocate): + artificial_column.append(random.choice(unique_strings_list)) + return artificial_column + + +def _generate_artificial_categories(num_rows: int, choices: List[Any]) -> List[str]: + """Create list of strings from given `choices` list.""" + artificial_column = choices.copy() + remaining_to_allocate = num_rows - len(choices) + for _ in range(remaining_to_allocate): + artificial_column.append(random.choice(choices)) + return artificial_column + + +def _generate_random_word(length: int) -> str: + """Generate a random word of the given length.""" + return "".join(random.choices(string.ascii_letters, k=length)) + + +def _generate_random_text_column(num_rows: int, length: int) -> List[str]: + """Generate a list of random text of specified length.""" + text_col = [] + for _ in range(num_rows): + text_col.append(_generate_random_word(length)) + return text_col + + +def _generate_random_sentence( + min_word_length: int, + max_word_length: int, + min_sentence_length: int, + max_sentence_length: int, +) -> str: + """Generate a random sentence with words of random lengths.""" + sentence_length = random.randint(min_sentence_length, max_sentence_length) + sentence: List[str] = [] + while len(" ".join(sentence)) < sentence_length: + word_length = random.randint(min_word_length, max_word_length) + word = _generate_random_word(word_length) + sentence.append(word) + return " ".join(sentence) + + +def _generate_random_sentences( + num_rows: int, + min_word_length: int, + max_word_length: int, + min_sentence_length: int, + max_sentence_length: int, +) -> List[str]: + 
"""Generate a list of random sentences.""" + text_col = [ + _generate_random_sentence( + min_word_length, max_word_length, min_sentence_length, max_sentence_length + ) + for _ in range(num_rows) + ] + return text_col + + +def _make_num_rows_none(column: List[Any], num_none: int) -> List[Any]: + """Assign none num_none times to the given list.""" + column_copy = column.copy() + none_positions = random.sample(range(len(column_copy)), num_none) + for pos in none_positions: + column_copy[pos] = None + return column_copy + + +def _generate_random_date( + start_date: datetime, + end_date: datetime, + date_format: str = "%a %b %d %H:%M:%S %Y", + as_string: bool = True, +) -> Union[str, datetime]: + """Generate a random date between start_date and end_date.""" + time_between_dates = end_date - start_date + random_seconds = random.randint(0, int(time_between_dates.total_seconds())) + random_date = start_date + timedelta(seconds=random_seconds) + + if as_string: + return random_date.strftime(date_format) + return random_date + + +def _generate_random_date_column( + num_rows: int, + start_date: datetime, + end_date: datetime, + date_format: str = "%a %b %d %H:%M:%S %Y", + as_string: bool = True, +) -> List[Union[str, datetime]]: + """Generate a list of random dates.""" + return [ + _generate_random_date(start_date, end_date, date_format, as_string) + for _ in range(num_rows) + ] + + +def _generate_random_int_column(num_rows: int, min_int: int, max_int: int) -> List[int]: + """Generate a list of ints.""" + return [random.randint(min_int, max_int) for _ in range(num_rows)] + + +def _generate_random_bool_column(num_rows: int) -> List[bool]: + """Generate a list of bools.""" + return [random.choice([True, False]) for _ in range(num_rows)] + + +def _generate_random_image_column( + num_rows: int, + image_size: Union[Tuple[int, int], Tuple[int, int, int]], + simulate_type: str, +) -> List[Any]: + """Simulate the images with the format that is found in HF Hub. 
+ + Directly using `Image.fromarray` does not work because it creates `PIL.Image.Image`. + """ + # Generate numpy images + np_images = [] + for _ in range(num_rows): + np_images.append(np.random.randint(0, 255, size=image_size, dtype=np.uint8)) + # Change the format to the PIL.PngImagePlugin.PngImageFile + # or the PIL.JpegImagePlugin.JpegImageFile format + pil_imgs = [] + for np_image in np_images: + # Convert the NumPy array to a PIL image + pil_img_beg = Image.fromarray(np_image) # type: ignore + + # Save the image to an in-memory bytes buffer + in_memory_file = io.BytesIO() + pil_img_beg.save(in_memory_file, format=simulate_type) + in_memory_file.seek(0) + + # Reload the image as a PngImageFile + pil_image_end = Image.open(in_memory_file) + pil_imgs.append(pil_image_end) + return pil_imgs + + +def generate_random_audio_column( + num_rows: int, + sampling_rate: int, + length_in_samples: int, +) -> List[Dict[str, Any]]: + """Simulate the audio column. + + Audio column in the dataset is comprised of an array of floats, sampling_rate and a + path.
+ """ + # Generate audio data + audios = [] + for _ in range(num_rows): + audio_array = np.random.uniform(low=-1.0, high=1.0, size=length_in_samples) + audios.append( + {"path": None, "array": audio_array, "sampling_rate": sampling_rate} + ) + return audios + + +def _mock_sentiment140(num_rows: int) -> Dataset: + users = _generate_artificial_strings( + num_rows=num_rows, num_unique=30, string_length=5 + ) + sentiment = _generate_artificial_categories(num_rows=num_rows, choices=[0, 4]) + query = ["NO_QUERY"] * num_rows + + # Sentences + min_word_length = 3 + max_word_length = 8 + min_sentence_length = 20 + max_sentence_length = 60 + + text = _generate_random_sentences( + num_rows, + min_word_length, + max_word_length, + min_sentence_length, + max_sentence_length, + ) + + start_date = datetime(2009, 1, 1) + end_date = datetime(2010, 12, 31) + date_format = "%a %b %d %H:%M:%S %Y" + + # Generate a list of random dates as strings + date = _generate_random_date_column( + num_rows, start_date, end_date, date_format, as_string=True + ) + + features = Features( + { + "text": Value(dtype="string"), + "date": Value(dtype="string"), + "user": Value(dtype="string"), + "sentiment": Value(dtype="int32"), + "query": Value(dtype="string"), + } + ) + dataset = datasets.Dataset.from_dict( + { + "user": users, + "sentiment": sentiment, + "query": query, + "text": text, + "date": date, + }, + features=features, + ) + return dataset + + +def _mock_cifar100(num_rows: int) -> Dataset: + imgs = _generate_random_image_column(num_rows, (32, 32, 3), "PNG") + unique_fine_labels = _generate_artificial_strings( + num_rows=100, num_unique=100, string_length=10, seed=42 + ) + fine_label = _generate_artificial_categories(num_rows, unique_fine_labels) + unique_coarse_labels = _generate_artificial_strings( + num_rows=20, num_unique=20, string_length=10, seed=42 + ) + + coarse_label = _generate_artificial_categories(num_rows, unique_coarse_labels) + features = Features( + { + "img": 
datasets.Image(decode=True), + "fine_label": ClassLabel(names=unique_fine_labels), + "coarse_label": ClassLabel(names=unique_coarse_labels), + } + ) + dataset = datasets.Dataset.from_dict( + {"img": imgs, "coarse_label": coarse_label, "fine_label": fine_label}, + features=features, + ) + return dataset + + +def _mock_svhn_cropped_digits(num_rows: int) -> Dataset: + imgs = _generate_random_image_column(num_rows, (32, 32, 3), "PNG") + unique_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] + label = _generate_artificial_categories(num_rows, unique_labels) + features = Features( + { + "image": datasets.Image(decode=True), + "label": ClassLabel(names=unique_labels), + } + ) + dataset = datasets.Dataset.from_dict( + {"image": imgs, "label": label}, features=features + ) + return dataset + + +def _mock_speach_commands(num_rows: int) -> Dataset: + sampling_rate = 16_000 + length_in_samples = 16_000 + imgs = generate_random_audio_column( + num_rows=num_rows, + sampling_rate=sampling_rate, + length_in_samples=length_in_samples, + ) + unique_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] + label = _generate_artificial_categories(num_rows, unique_labels) + is_unknown = _generate_random_bool_column(num_rows) + utterance_id = _generate_random_int_column(num_rows, 0, 10) + unique_ids = _generate_random_text_column(num_rows // 10, 5) + speaker_id = _generate_artificial_categories(num_rows, unique_ids) + speaker_id = _make_num_rows_none(speaker_id, 10) + features = Features( + { + "audio": datasets.Audio( + sampling_rate=sampling_rate, mono=True, decode=True + ), + "is_unknown": Value(dtype="bool"), + "speaker_id": Value(dtype="string"), + "utterance_id": Value(dtype="int8"), + "label": ClassLabel(names=unique_labels), + } + ) + dataset = datasets.Dataset.from_dict( + { + "audio": imgs, + "is_unknown": is_unknown, + "speaker_id": speaker_id, + "utterance_id": utterance_id, + "label": label, + }, + features=features, + ) + return dataset + + 
+dataset_name_to_mock_function = { + "cifar100": _mock_cifar100, + "sentiment140": _mock_sentiment140, + "svhn_cropped_digits": _mock_svhn_cropped_digits, + "speech_commands_v0.01": _mock_speach_commands, +} + + +def _load_mocked_dataset( + dataset_name: str, + num_rows: List[int], + split_names: List[str], + subset: str = "", +) -> DatasetDict: + dataset_dict = {} + name = dataset_name if subset == "" else dataset_name + "_" + subset + dataset_creation_fnc = dataset_name_to_mock_function[name] + for params in zip(num_rows, split_names): + dataset_dict[params[1]] = dataset_creation_fnc(params[0]) + return datasets.DatasetDict(dataset_dict) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 346d897ccdd6..c6f6900a99cd 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -29,6 +29,11 @@ "fashion_mnist", "sasha/dog-food", "zh-plus/tiny-imagenet", + "scikit-learn/adult-census-income", + "cifar100", + "svhn", + "sentiment140", + "speech_commands", ]