Add datasets tests for FDS (#2964)
Co-authored-by: Javier <[email protected]>
adam-narozniak and jafermarq authored Apr 25, 2024
1 parent 8379d97 commit f50f690
Showing 3 changed files with 480 additions and 16 deletions.
114 changes: 98 additions & 16 deletions datasets/flwr_datasets/federated_dataset_test.py
@@ -20,29 +20,70 @@
from typing import Dict, Union
from unittest.mock import Mock, patch

+import numpy as np
import pytest
from parameterized import parameterized, parameterized_class

import datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
+from flwr_datasets.mock_utils_test import _load_mocked_dataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

+mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]
+

@parameterized_class(
+    ("dataset_name", "test_split", "subset"),
    [
-        {"dataset_name": "mnist", "test_split": "test"},
-        {"dataset_name": "cifar10", "test_split": "test"},
-        {"dataset_name": "fashion_mnist", "test_split": "test"},
-        {"dataset_name": "sasha/dog-food", "test_split": "test"},
-        {"dataset_name": "zh-plus/tiny-imagenet", "test_split": "valid"},
-    ]
+        # Downloaded
+        # #Image datasets
+        ("mnist", "test", ""),
+        ("cifar10", "test", ""),
+        ("fashion_mnist", "test", ""),
+        ("sasha/dog-food", "test", ""),
+        ("zh-plus/tiny-imagenet", "valid", ""),
+        # Text
+        ("scikit-learn/adult-census-income", None, ""),
+        # Mocked
+        # #Image
+        ("cifar100", "test", ""),
+        # Note: there's also the extra split and full_numbers subset
+        ("svhn", "test", "cropped_digits"),
+        # Text
+        ("sentiment140", "test", ""),  # aka twitter
+        # Audio
+        ("speech_commands", "test", "v0.01"),
+    ],
)
-class RealDatasetsFederatedDatasetsTrainTest(unittest.TestCase):
-    """Test Real Dataset (MNIST, CIFAR10) in FederatedDatasets."""
+class BaseFederatedDatasetsTest(unittest.TestCase):
+    """Test Real/Mocked Datasets used in FederatedDatasets.
+
+    The setUp method mocks the dataset download via datasets.load_dataset if it is in
+    the `mocked_datasets` list.
+    """

    dataset_name = ""
    test_split = ""
+    subset = ""

+    def setUp(self) -> None:
+        """Mock the dataset download prior to each method if needed.
+
+        If the `dataset_name` is in the `mocked_datasets` list, then the dataset
+        download is mocked.
+        """
+        if self.dataset_name in mocked_datasets:
+            self.patcher = patch("datasets.load_dataset")
+            self.mock_load_dataset = self.patcher.start()
+            self.mock_load_dataset.return_value = _load_mocked_dataset(
+                self.dataset_name, [200, 100], ["train", self.test_split], self.subset
+            )
+
+    def tearDown(self) -> None:
+        """Clean up after the dataset mocking."""
+        if self.dataset_name in mocked_datasets:
+            patch.stopall()
+
    @parameterized.expand(  # type: ignore
        [
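
For reference, the setUp/tearDown pair above works by patching datasets.load_dataset globally, so every dataset "download" inside a test receives an artificial dataset instead. A minimal self-contained sketch of the same pattern, with a hypothetical return value (the real tests build theirs via _load_mocked_dataset):

    import unittest
    from unittest.mock import patch

    from datasets import Dataset, DatasetDict


    class MockedLoadDatasetExample(unittest.TestCase):
        """Illustrate the patch/start/stopall pattern used above."""

        def setUp(self) -> None:
            # Replace datasets.load_dataset with a Mock for each test method.
            self.patcher = patch("datasets.load_dataset")
            mock_load_dataset = self.patcher.start()
            # Any datasets.load_dataset(...) call now returns this artificial data.
            mock_load_dataset.return_value = DatasetDict(
                {"train": Dataset.from_dict({"label": list(range(10))})}
            )

        def tearDown(self) -> None:
            # Stop every patch started with patcher.start().
            patch.stopall()
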
@@ -61,14 +102,25 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
        dataset_fds = FederatedDataset(
            dataset=self.dataset_name, partitioners={"train": train_num_partitions}
        )
-        dataset_partition0 = dataset_fds.load_partition(0, "train")
+        # Compute the actual partition sizes
+        partition_sizes = []
+        for node_id in range(train_num_partitions):
+            partition_sizes.append(len(dataset_fds.load_partition(node_id, "train")))
+
+        # Create the expected sizes of partitions
        dataset = datasets.load_dataset(self.dataset_name)
-        self.assertEqual(
-            len(dataset_partition0), len(dataset["train"]) // train_num_partitions
-        )
+        full_train_length = len(dataset["train"])
+        expected_sizes = []
+        default_partition_size = full_train_length // train_num_partitions
+        mod = full_train_length % train_num_partitions
+        for i in range(train_num_partitions):
+            expected_sizes.append(default_partition_size + (1 if i < mod else 0))
+        self.assertEqual(partition_sizes, expected_sizes)

    def test_load_split(self) -> None:
        """Test if the load_split works with the correct split name."""
+        if self.test_split is None:
+            return
        dataset_fds = FederatedDataset(
            dataset=self.dataset_name, partitioners={"train": 100}
        )
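
The expected-size arithmetic in test_load_partition_size above mirrors the IID split the default partitioner is expected to produce: every partition receives full_train_length // train_num_partitions rows, and the first full_train_length % train_num_partitions partitions receive one extra row. A standalone check of that arithmetic (plain Python, the row counts are made up for illustration):

    def expected_partition_sizes(num_rows: int, num_partitions: int) -> list:
        """Return IID partition sizes; remainder rows go to the first partitions."""
        base, mod = divmod(num_rows, num_partitions)
        return [base + (1 if i < mod else 0) for i in range(num_partitions)]

    # 60_000 rows over 7 partitions: the first 3 partitions absorb the remainder.
    assert expected_partition_sizes(60_000, 7) == [8572] * 3 + [8571] * 4
    assert sum(expected_partition_sizes(60_000, 7)) == 60_000
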
@@ -78,6 +130,8 @@ def test_load_split(self) -> None:

    def test_multiple_partitioners(self) -> None:
        """Test if the dataset works when multiple partitioners are specified."""
+        if self.test_split is None:
+            return
        num_train_partitions = 100
        num_test_partitions = 100
        dataset_fds = FederatedDataset(
@@ -97,7 +151,7 @@ def test_multiple_partitioners(self) -> None:

    def test_no_need_for_split_keyword_if_one_partitioner(self) -> None:
        """Test if partitions got with and without split args are the same."""
-        fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
+        fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10})
        partition_loaded_with_no_split_arg = fds.load_partition(0)
        partition_loaded_with_verbose_split_arg = fds.load_partition(0, "train")
        self.assertTrue(
@@ -109,6 +163,8 @@ def test_no_need_for_split_keyword_if_one_partitioner(self) -> None:

    def test_resplit_dataset_into_one(self) -> None:
        """Test resplit into a single dataset."""
+        if self.test_split is None:
+            return
        dataset = datasets.load_dataset(self.dataset_name)
        dataset_length = sum([len(ds) for ds in dataset.values()])
        fds = FederatedDataset(
@@ -122,6 +178,8 @@ def test_resplit_dataset_into_one(self) -> None:
    # pylint: disable=protected-access
    def test_resplit_dataset_to_change_names(self) -> None:
        """Test resplitter to change the names of the partitions."""
+        if self.test_split is None:
+            return
        fds = FederatedDataset(
            dataset=self.dataset_name,
            partitioners={"new_train": 100},
@@ -138,6 +196,8 @@ def test_resplit_dataset_to_change_names(self) -> None:

    def test_resplit_dataset_by_callable(self) -> None:
        """Test resplitter to change the names of the partitions."""
+        if self.test_split is None:
+            return

        def resplit(dataset: DatasetDict) -> DatasetDict:
            return DatasetDict(
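
The resplit callable is cut off at this point in the view. A sketch of what such a callable can look like, assuming (as in test_resplit_dataset_into_one above, and the length assertion in the next hunk) that it merges the existing splits into a single "full" split:

    from datasets import DatasetDict, concatenate_datasets

    def resplit(dataset: DatasetDict) -> DatasetDict:
        # Concatenate the train and test splits into one "full" split.
        return DatasetDict(
            {"full": concatenate_datasets([dataset["train"], dataset["test"]])}
        )
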
@@ -157,8 +217,13 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
        self.assertEqual(len(full), dataset_length)


-class ArtificialDatasetTest(unittest.TestCase):
-    """Test using small artificial dataset, mocked load_dataset."""
+class ShufflingResplittingOnArtificialDatasetTest(unittest.TestCase):
+    """Test shuffling and resplitting using small artificial dataset.
+
+    The purpose of this class is to ensure the order of samples remains as expected.
+
+    The load_dataset method is mocked and the artificial dataset is returned.
+    """

    # pylint: disable=no-self-use
    def _dummy_setup(self, train_rows: int = 10, test_rows: int = 5) -> DatasetDict:
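
The body of _dummy_setup is collapsed in this view. A hedged sketch of how such an artificial DatasetDict can be constructed (the column names are illustrative, not taken from the commit):

    from datasets import Dataset, DatasetDict

    def dummy_dataset(train_rows: int = 10, test_rows: int = 5) -> DatasetDict:
        # Deterministic, ordered values make shuffling and resplitting verifiable.
        return DatasetDict(
            {
                "train": Dataset.from_dict({"features": list(range(train_rows))}),
                "test": Dataset.from_dict({"features": list(range(test_rows))}),
            }
        )
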
@@ -360,9 +425,26 @@ def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:

    # Iterate over each row and check for equality
    for row1, row2 in zip(ds1, ds2):
-        if row1 != row2:
-            return False
+        # Ensure all keys are the same in both rows
+        if set(row1.keys()) != set(row2.keys()):
+            return False
+
+        # Compare values for each key
+        for key in row1:
+            if key == "audio":
+                # Special handling for 'audio' key
+                if not all(
+                    [
+                        np.array_equal(row1[key]["array"], row2[key]["array"]),
+                        row1[key]["path"] == row2[key]["path"],
+                        row1[key]["sampling_rate"] == row2[key]["sampling_rate"],
+                    ]
+                ):
+                    return False
+            elif row1[key] != row2[key]:
+                # Direct comparison for other keys
+                return False

    return True
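
A quick usage sketch for the datasets_are_equal helper above (the datasets here are invented for illustration): given equal-length inputs, the row-by-row comparison returns True only when both datasets contain the same rows in the same order.

    from datasets import Dataset

    ds_a = Dataset.from_dict({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    ds_b = Dataset.from_dict({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    ds_c = Dataset.from_dict({"x": [3, 2, 1], "y": ["c", "b", "a"]})

    assert datasets_are_equal(ds_a, ds_b)      # identical rows and order
    assert not datasets_are_equal(ds_a, ds_c)  # same rows, different order
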


