Skip to content

Commit

Permalink
Add tests for ucf101 dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jul 18, 2024
1 parent 016348b commit 5292923
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
23 changes: 22 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,18 @@ def test_mixed_type_partitioners_creates_from_int(self) -> None:
)


mocked_natural_id_datasets = [
natural_id_datasets = [
"flwrlabs/femnist",
]

mocked_natural_id_datasets = ["flwrlabs/ucf101"]


@parameterized_class(
("dataset_name", "test_split", "subset", "partition_by"),
[
("flwrlabs/femnist", "", "", "writer_id"),
("flwrlabs/ucf101", "test", "", "video_id"),
],
)
class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
Expand All @@ -399,6 +402,24 @@ class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
subset = ""
partition_by = ""

def setUp(self) -> None:
"""Mock the dataset download prior to each method if needed.
If the `dataset_name` is in the `mocked_datasets` list, then the dataset
download is mocked.
"""
if self.dataset_name in mocked_natural_id_datasets:
self.patcher = patch("datasets.load_dataset")
self.mock_load_dataset = self.patcher.start()
self.mock_load_dataset.return_value = _load_mocked_dataset(
self.dataset_name, [20, 10], ["train", self.test_split], self.subset
)

def tearDown(self) -> None:
"""Clean up after the dataset mocking."""
if self.dataset_name in mocked_datasets:
patch.stopall()

def test_if_the_partitions_have_unique_values(self) -> None:
"""Test if each partition has a single unique id value."""
fds = FederatedDataset(
Expand Down
20 changes: 20 additions & 0 deletions datasets/flwr_datasets/mock_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,31 @@ def _mock_speach_commands(num_rows: int) -> Dataset:
return dataset


def _mock_ucf101(num_rows: int) -> Dataset:
imgs = _generate_random_image_column(num_rows, (320, 240, 3), "JPEG")
unique_video_id = ["0", "1", "2", "3", "4"]
unique_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
label = _generate_artificial_categories(num_rows, unique_labels)
video_id = _generate_artificial_categories(num_rows, unique_video_id)
features = Features(
{
"image": datasets.Image(decode=True),
"video_id": Value(dtype="string"),
"label": ClassLabel(names=unique_labels),
}
)
dataset = datasets.Dataset.from_dict(
{"image": imgs, "video_id": video_id, "label": label}, features=features
)
return dataset


dataset_name_to_mock_function = {
"cifar100": _mock_cifar100,
"sentiment140": _mock_sentiment140,
"svhn_cropped_digits": _mock_svhn_cropped_digits,
"speech_commands_v0.01": _mock_speach_commands,
"flwrlabs/ucf101": _mock_ucf101,
}


Expand Down
1 change: 1 addition & 0 deletions datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"sentiment140",
"speech_commands",
"flwrlabs/femnist",
"flwrlabs/ucf101",
]


Expand Down

0 comments on commit 5292923

Please sign in to comment.