diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py
index b66874cb34e2..aa6075d9e61a 100644
--- a/datasets/flwr_datasets/federated_dataset_test.py
+++ b/datasets/flwr_datasets/federated_dataset_test.py
@@ -380,15 +380,18 @@ def test_mixed_type_partitioners_creates_from_int(self) -> None:
         )
 
 
-mocked_natural_id_datasets = [
+natural_id_datasets = [
     "flwrlabs/femnist",
 ]
 
+mocked_natural_id_datasets = ["flwrlabs/ucf101"]
+
 
 @parameterized_class(
     ("dataset_name", "test_split", "subset", "partition_by"),
     [
         ("flwrlabs/femnist", "", "", "writer_id"),
+        ("flwrlabs/ucf101", "test", "", "video_id"),
     ],
 )
 class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
@@ -399,6 +402,24 @@ class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
     subset = ""
     partition_by = ""
 
+    def setUp(self) -> None:
+        """Mock the dataset download prior to each method if needed.
+
+        If the `dataset_name` is in the `mocked_natural_id_datasets` list, then the
+        dataset download is mocked.
+        """
+        if self.dataset_name in mocked_natural_id_datasets:
+            self.patcher = patch("datasets.load_dataset")
+            self.mock_load_dataset = self.patcher.start()
+            self.mock_load_dataset.return_value = _load_mocked_dataset(
+                self.dataset_name, [20, 10], ["train", self.test_split], self.subset
+            )
+
+    def tearDown(self) -> None:
+        """Stop the `load_dataset` patch if `setUp` started one for this dataset."""
+        if self.dataset_name in mocked_natural_id_datasets:
+            patch.stopall()
+
     def test_if_the_partitions_have_unique_values(self) -> None:
         """Test if each partition has a single unique id value."""
         fds = FederatedDataset(
diff --git a/datasets/flwr_datasets/mock_utils_test.py b/datasets/flwr_datasets/mock_utils_test.py
index bd49de8033de..f31fe17c2543 100644
--- a/datasets/flwr_datasets/mock_utils_test.py
+++ b/datasets/flwr_datasets/mock_utils_test.py
@@ -355,11 +355,32 @@ def _mock_speach_commands(num_rows: int) -> Dataset:
     return dataset
 
 
+def _mock_ucf101(num_rows: int) -> Dataset:
+    """Mock the flwrlabs/ucf101 dataset with image, video_id and label columns."""
+    imgs = _generate_random_image_column(num_rows, (320, 240, 3), "JPEG")
+    unique_video_id = ["0", "1", "2", "3", "4"]
+    unique_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
+    label = _generate_artificial_categories(num_rows, unique_labels)
+    video_id = _generate_artificial_categories(num_rows, unique_video_id)
+    features = Features(
+        {
+            "image": datasets.Image(decode=True),
+            "video_id": Value(dtype="string"),
+            "label": ClassLabel(names=unique_labels),
+        }
+    )
+    dataset = datasets.Dataset.from_dict(
+        {"image": imgs, "video_id": video_id, "label": label}, features=features
+    )
+    return dataset
+
+
 dataset_name_to_mock_function = {
     "cifar100": _mock_cifar100,
     "sentiment140": _mock_sentiment140,
     "svhn_cropped_digits": _mock_svhn_cropped_digits,
     "speech_commands_v0.01": _mock_speach_commands,
+    "flwrlabs/ucf101": _mock_ucf101,
 }
 
 
diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py
index 746dc85478f1..282a33ed10a0 100644
--- a/datasets/flwr_datasets/utils.py
+++ b/datasets/flwr_datasets/utils.py
@@ -35,6 +35,7 @@
     "sentiment140",
     "speech_commands",
     "flwrlabs/femnist",
+    "flwrlabs/ucf101",
 ]
 
 