Skip to content

Commit

Permalink
Mock ucf101 by partial download
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jul 22, 2024
1 parent 02e085f commit 2e955c9
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
import datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.mock_utils_test import _load_mocked_dataset
from flwr_datasets.mock_utils_test import (
_load_mocked_dataset,
_load_mocked_dataset_dict_by_partial_download,
)
from flwr_datasets.partitioner import IidPartitioner, NaturalIdPartitioner, Partitioner

mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]
Expand Down Expand Up @@ -411,7 +414,7 @@ def test_mixed_type_partitioners_creates_from_int(self) -> None:
("dataset_name", "test_split", "subset", "partition_by"),
[
("flwrlabs/femnist", "", "", "writer_id"),
("flwrlabs/ucf101", "test", "", "video_id"),
("flwrlabs/ucf101", "test", None, "video_id"),
],
)
class NaturalIdPartitionerIntegrationTest(unittest.TestCase):
Expand All @@ -429,15 +432,19 @@ def setUp(self) -> None:
download is mocked.
"""
if self.dataset_name in mocked_natural_id_datasets:
mock_return_value = _load_mocked_dataset_dict_by_partial_download(
dataset_name=self.dataset_name,
split_names=["train"],
skip_take_lists=[[(0, 30), (1000, 30), (2000, 40)]],
subset_name=self.subset,
)
self.patcher = patch("datasets.load_dataset")
self.mock_load_dataset = self.patcher.start()
self.mock_load_dataset.return_value = _load_mocked_dataset(
self.dataset_name, [20, 10], ["train", self.test_split], self.subset
)
self.mock_load_dataset.return_value = mock_return_value

def tearDown(self) -> None:
"""Clean up after the dataset mocking."""
if self.dataset_name in mocked_datasets:
if self.dataset_name in mocked_natural_id_datasets:
patch.stopall()

def test_if_the_partitions_have_unique_values(self) -> None:
Expand All @@ -453,6 +460,22 @@ def test_if_the_partitions_have_unique_values(self) -> None:
unique_ids_in_partition = list(set(partition[self.partition_by]))
self.assertEqual(len(unique_ids_in_partition), 1)

def tests_if_the_columns_are_unchanged(self) -> None:
"""Test if the columns are unchanged after partitioning."""
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={
"train": NaturalIdPartitioner(partition_by=self.partition_by)
},
)
dataset = fds.load_split("train")
columns_in_dataset = set(dataset.column_names)

for partition_id in range(fds.partitioners["train"].num_partitions):
partition = fds.load_partition(partition_id)
columns_in_partition = set(partition.column_names)
self.assertEqual(columns_in_partition, columns_in_dataset)


class IncorrectUsageFederatedDatasets(unittest.TestCase):
"""Test incorrect usages in FederatedDatasets."""
Expand Down

0 comments on commit 2e955c9

Please sign in to comment.