diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py
index 6dc685df4c40..87cb541a430e 100644
--- a/datasets/flwr_datasets/federated_dataset_test.py
+++ b/datasets/flwr_datasets/federated_dataset_test.py
@@ -27,7 +27,10 @@
 import datasets
 from datasets import Dataset, DatasetDict, concatenate_datasets
 from flwr_datasets.federated_dataset import FederatedDataset
-from flwr_datasets.mock_utils_test import _load_mocked_dataset
+from flwr_datasets.mock_utils_test import (
+    _load_mocked_dataset,
+    _load_mocked_dataset_dict_by_partial_download,
+)
 from flwr_datasets.partitioner import IidPartitioner, NaturalIdPartitioner, Partitioner
 
 mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]
@@ -411,7 +414,7 @@ def test_mixed_type_partitioners_creates_from_int(self) -> None:
     ("dataset_name", "test_split", "subset", "partition_by"),
     [
         ("flwrlabs/femnist", "", "", "writer_id"),
-        ("flwrlabs/ucf101", "test", "", "video_id"),
+        ("flwrlabs/ucf101", "test", None, "video_id"),
     ],
 )
 class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
@@ -429,15 +432,19 @@ def setUp(self) -> None:
         download is mocked.
         """
         if self.dataset_name in mocked_natural_id_datasets:
+            mock_return_value = _load_mocked_dataset_dict_by_partial_download(
+                dataset_name=self.dataset_name,
+                split_names=["train"],
+                skip_take_lists=[[(0, 30), (1000, 30), (2000, 40)]],
+                subset_name=self.subset,
+            )
             self.patcher = patch("datasets.load_dataset")
             self.mock_load_dataset = self.patcher.start()
-            self.mock_load_dataset.return_value = _load_mocked_dataset(
-                self.dataset_name, [20, 10], ["train", self.test_split], self.subset
-            )
+            self.mock_load_dataset.return_value = mock_return_value
 
     def tearDown(self) -> None:
         """Clean up after the dataset mocking."""
-        if self.dataset_name in mocked_datasets:
+        if self.dataset_name in mocked_natural_id_datasets:
             patch.stopall()
 
     def test_if_the_partitions_have_unique_values(self) -> None:
@@ -453,6 +460,22 @@ def test_if_the_partitions_have_unique_values(self) -> None:
             unique_ids_in_partition = list(set(partition[self.partition_by]))
             self.assertEqual(len(unique_ids_in_partition), 1)
 
+    def tests_if_the_columns_are_unchanged(self) -> None:
+        """Test if the columns are unchanged after partitioning."""
+        fds = FederatedDataset(
+            dataset=self.dataset_name,
+            partitioners={
+                "train": NaturalIdPartitioner(partition_by=self.partition_by)
+            },
+        )
+        dataset = fds.load_split("train")
+        columns_in_dataset = set(dataset.column_names)
+
+        for partition_id in range(fds.partitioners["train"].num_partitions):
+            partition = fds.load_partition(partition_id)
+            columns_in_partition = set(partition.column_names)
+            self.assertEqual(columns_in_partition, columns_in_dataset)
+
 
 class IncorrectUsageFederatedDatasets(unittest.TestCase):
     """Test incorrect usages in FederatedDatasets."""