feat(datasets) Add tests for ucf101 dataset (#3842)
Co-authored-by: jafermarq <[email protected]>
adam-narozniak and jafermarq authored Jul 22, 2024
1 parent b7ae03d commit 9f685be
Showing 2 changed files with 46 additions and 1 deletion.
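In short, the new NaturalIdPartitionerIntegrationTest parametrization covers the usage sketched below. This is a minimal sketch assembled from the diff, not code added by the commit; without the mocking the tests set up, it would download the real dataset:

from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import NaturalIdPartitioner

# Partition flwrlabs/ucf101 by its natural "video_id" key, as the new tests do.
fds = FederatedDataset(
    dataset="flwrlabs/ucf101",
    partitioners={"train": NaturalIdPartitioner(partition_by="video_id")},
)
partition = fds.load_partition(0)  # every row in a partition shares one video_id
num_partitions = fds.partitioners["train"].num_partitions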
46 changes: 45 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
@@ -27,7 +27,10 @@
import datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.mock_utils_test import _load_mocked_dataset
from flwr_datasets.mock_utils_test import (
_load_mocked_dataset,
_load_mocked_dataset_dict_by_partial_download,
)
from flwr_datasets.partitioner import IidPartitioner, NaturalIdPartitioner, Partitioner

mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]
@@ -403,11 +406,14 @@ def test_mixed_type_partitioners_creates_from_int(self) -> None:
"flwrlabs/femnist",
]

mocked_natural_id_datasets = ["flwrlabs/ucf101"]


@parameterized_class(
("dataset_name", "test_split", "subset", "partition_by"),
[
("flwrlabs/femnist", "", "", "writer_id"),
("flwrlabs/ucf101", "test", None, "video_id"),
],
)
class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
@@ -418,6 +424,28 @@ class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
subset = ""
partition_by = ""

def setUp(self) -> None:
"""Mock the dataset download prior to each method if needed.
If the `dataset_name` is in the `mocked_datasets` list, then the dataset
download is mocked.
"""
if self.dataset_name in mocked_natural_id_datasets:
mock_return_value = _load_mocked_dataset_dict_by_partial_download(
dataset_name=self.dataset_name,
split_names=["train"],
skip_take_lists=[[(0, 30), (1000, 30), (2000, 40)]],
subset_name=self.subset,
)
self.patcher = patch("datasets.load_dataset")
self.mock_load_dataset = self.patcher.start()
self.mock_load_dataset.return_value = mock_return_value

def tearDown(self) -> None:
"""Clean up after the dataset mocking."""
if self.dataset_name in mocked_natural_id_datasets:
patch.stopall()

def test_if_the_partitions_have_unique_values(self) -> None:
"""Test if each partition has a single unique id value."""
fds = FederatedDataset(
@@ -431,6 +459,22 @@ def test_if_the_partitions_have_unique_values(self) -> None:
unique_ids_in_partition = list(set(partition[self.partition_by]))
self.assertEqual(len(unique_ids_in_partition), 1)

def tests_if_the_columns_are_unchanged(self) -> None:
"""Test if the columns are unchanged after partitioning."""
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={
"train": NaturalIdPartitioner(partition_by=self.partition_by)
},
)
dataset = fds.load_split("train")
columns_in_dataset = set(dataset.column_names)

for partition_id in range(fds.partitioners["train"].num_partitions):
partition = fds.load_partition(partition_id)
columns_in_partition = set(partition.column_names)
self.assertEqual(columns_in_partition, columns_in_dataset)


class IncorrectUsageFederatedDatasets(unittest.TestCase):
"""Test incorrect usages in FederatedDatasets."""
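The setUp in NaturalIdPartitionerIntegrationTest above patches datasets.load_dataset so the UCF101 tests never download the full dataset; _load_mocked_dataset_dict_by_partial_download receives per-split lists of (skip, take) row ranges. The helper's internals are not shown in this diff, so the following is only a hedged sketch of how such a partial-download mock could be assembled (assumed approach: streaming plus skip/take), not the actual implementation from mock_utils_test:

from unittest.mock import patch

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset


def tiny_split(name, subset, split, skip_take_pairs):
    """Stream a split and materialize only a few (skip, take) row ranges."""
    streamed = load_dataset(name, subset, split=split, streaming=True)
    pieces = [
        Dataset.from_list(list(streamed.skip(skip).take(take)))
        for skip, take in skip_take_pairs
    ]
    return concatenate_datasets(pieces)


# Roughly what the mocked return value in setUp amounts to: ~100 rows of "train".
mock_value = DatasetDict(
    {
        "train": tiny_split(
            "flwrlabs/ucf101", None, "train", [(0, 30), (1000, 30), (2000, 40)]
        )
    }
)
with patch("datasets.load_dataset", return_value=mock_value):
    ...  # any FederatedDataset built here only ever sees the mocked rows

Because the patch targets datasets.load_dataset globally, tearDown calls patch.stopall() so the mock does not leak into other test classes.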
1 change: 1 addition & 0 deletions datasets/flwr_datasets/utils.py
@@ -36,6 +36,7 @@
"speech_commands",
"LIUM/tedlium", # Feature wise it's just like speech_commands
"flwrlabs/femnist",
"flwrlabs/ucf101",
"jlh/uci-mushrooms",
"Mike0307/MNIST-M",
"flwrlabs/usps",
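To exercise the updated test module locally, the standard unittest loader can target it directly. This is a generic invocation, not something added by the commit; note that the non-mocked integration tests in this module download real datasets:

import unittest

# Load and run every test in the changed module (generic unittest usage).
suite = unittest.defaultTestLoader.loadTestsFromName(
    "flwr_datasets.federated_dataset_test"
)
unittest.TextTestRunner(verbosity=2).run(suite)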
