Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(datasets) Add function to perform partial download of dataset for tests #3860

Merged
merged 5 commits into from
Jul 22, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion datasets/flwr_datasets/mock_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import random
import string
from datetime import datetime, timedelta
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -375,3 +375,62 @@ def _load_mocked_dataset(
for params in zip(num_rows, split_names):
dataset_dict[params[1]] = dataset_creation_fnc(params[0])
return datasets.DatasetDict(dataset_dict)


def _load_mocked_dataset_by_partial_download(
    dataset_name: str,
    split_name: str,
    skip_take_list: List[Tuple[int, int]],
    subset_name: Optional[str] = None,
) -> Dataset:
    """Download a partial dataset.

    This functionality is not supported in the datasets library. This is an informal
    way of achieving this by using `streaming=True` and creating a datasets.Dataset
    from in-memory objects.

    Parameters
    ----------
    dataset_name: str
        Name of the dataset (passed to load_dataset).
    split_name: str
        Name of the split (passed to load_dataset) e.g. "train".
    skip_take_list: List[Tuple[int, int]]
        The streaming mode has a specific type of accessing the data, the first tuple
        value is how many samples to skip, the second is how many samples to take.
        Every pair is applied to the beginning of the stream (the skips are not
        cumulative), and the resulting chunks are concatenated in the given order.
        Due to this mechanism, diverse samples can be taken (especially if the
        dataset is sorted by the natural_id for NaturalIdPartitioner).
    subset_name: Optional[str]
        Name of the subset (passed to load_dataset) e.g. "v0.01" for speech_commands.

    Returns
    -------
    dataset: Dataset
        The dataset with the requested samples.
    """
    dataset = datasets.load_dataset(
        dataset_name, name=subset_name, split=split_name, streaming=True
    )
    # A list of dicts such that each dict represents a single sample of the dataset.
    dataset_list: List[Dict[str, Any]] = []
    for skip, take in skip_take_list:
        # dataset.skip(n).take(m) in streaming mode yields the samples at 0-based
        # positions n, ..., n + m - 1, i.e. the same rows as dataset[n : n + m] on
        # the fully downloaded dataset. `extend` consumes the stream directly, so no
        # intermediate list is needed.
        dataset_list.extend(dataset.skip(skip).take(take))
    # NOTE(review): no explicit features are passed to from_list here; confirm that
    # the resulting column types match the fully downloaded dataset where it matters.
    return Dataset.from_list(dataset_list)


def _load_mocked_dataset_dict_by_partial_download(
    dataset_name: str,
    split_names: List[str],
    skip_take_lists: List[List[Tuple[int, int]]],
    subset_name: Optional[str] = None,
) -> DatasetDict:
    """Like _load_mocked_dataset_by_partial_download but for many splits.

    Parameters
    ----------
    dataset_name: str
        Name of the dataset (passed to load_dataset).
    split_names: List[str]
        Names of the splits to download, e.g. ["train", "test"].
    skip_take_lists: List[List[Tuple[int, int]]]
        One list of (skip, take) pairs per split, aligned by position with
        `split_names`.
    subset_name: Optional[str]
        Name of the subset (passed to load_dataset).

    Returns
    -------
    dataset_dict: DatasetDict
        Mapping from each split name to its partially downloaded dataset.

    Raises
    ------
    ValueError
        If `split_names` and `skip_take_lists` have different lengths.
    """
    # zip() would silently drop the unmatched splits, so fail loudly instead.
    if len(split_names) != len(skip_take_lists):
        raise ValueError(
            "`split_names` and `skip_take_lists` must have the same length, but got "
            f"{len(split_names)} and {len(skip_take_lists)}."
        )
    return DatasetDict(
        {
            split_name: _load_mocked_dataset_by_partial_download(
                dataset_name, split_name, skip_take_list, subset_name
            )
            for split_name, skip_take_list in zip(split_names, skip_take_lists)
        }
    )