diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 38f032a4ccfe..278cf4c9f5ad 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -27,7 +27,10 @@ import datasets from datasets import Dataset, DatasetDict, concatenate_datasets from flwr_datasets.federated_dataset import FederatedDataset -from flwr_datasets.mock_utils_test import _load_mocked_dataset +from flwr_datasets.mock_utils_test import ( + _load_mocked_dataset, + _load_mocked_dataset_dict_by_partial_download, +) from flwr_datasets.partitioner import IidPartitioner, NaturalIdPartitioner, Partitioner mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"] @@ -403,11 +406,14 @@ def test_mixed_type_partitioners_creates_from_int(self) -> None: "flwrlabs/femnist", ] +mocked_natural_id_datasets = ["flwrlabs/ucf101"] + @parameterized_class( ("dataset_name", "test_split", "subset", "partition_by"), [ ("flwrlabs/femnist", "", "", "writer_id"), + ("flwrlabs/ucf101", "test", None, "video_id"), ], ) class NaturalIdPartitionerIntegrationTest(unittest.TestCase): @@ -418,6 +424,28 @@ class NaturalIdPartitionerIntegrationTest(unittest.TestCase): subset = "" partition_by = "" + def setUp(self) -> None: + """Mock the dataset download prior to each method if needed. + + If the `dataset_name` is in the `mocked_datasets` list, then the dataset + download is mocked. + """ + if self.dataset_name in mocked_natural_id_datasets: + mock_return_value = _load_mocked_dataset_dict_by_partial_download( + dataset_name=self.dataset_name, + split_names=["train"], + skip_take_lists=[[(0, 30), (1000, 30), (2000, 40)]], + subset_name=self.subset, + ) + self.patcher = patch("datasets.load_dataset") + self.mock_load_dataset = self.patcher.start() + self.mock_load_dataset.return_value = mock_return_value + + def tearDown(self) -> None: + """Clean up after the dataset mocking.""" + if self.dataset_name in mocked_natural_id_datasets: + patch.stopall() + def test_if_the_partitions_have_unique_values(self) -> None: """Test if each partition has a single unique id value.""" fds = FederatedDataset( @@ -431,6 +459,22 @@ def test_if_the_partitions_have_unique_values(self) -> None: unique_ids_in_partition = list(set(partition[self.partition_by])) self.assertEqual(len(unique_ids_in_partition), 1) + def tests_if_the_columns_are_unchanged(self) -> None: + """Test if the columns are unchanged after partitioning.""" + fds = FederatedDataset( + dataset=self.dataset_name, + partitioners={ + "train": NaturalIdPartitioner(partition_by=self.partition_by) + }, + ) + dataset = fds.load_split("train") + columns_in_dataset = set(dataset.column_names) + + for partition_id in range(fds.partitioners["train"].num_partitions): + partition = fds.load_partition(partition_id) + columns_in_partition = set(partition.column_names) + self.assertEqual(columns_in_partition, columns_in_dataset) + class IncorrectUsageFederatedDatasets(unittest.TestCase): """Test incorrect usages in FederatedDatasets.""" diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 578270c3735a..955ea041c2a6 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -36,6 +36,7 @@ "speech_commands", "LIUM/tedlium", # Feature wise it's just like speech_commands "flwrlabs/femnist", + "flwrlabs/ucf101", "jlh/uci-mushrooms", "Mike0307/MNIST-M", "flwrlabs/usps",