diff --git a/datasets/flwr_datasets/mock_utils_test.py b/datasets/flwr_datasets/mock_utils_test.py index bd49de8033de..7ee3bae890ff 100644 --- a/datasets/flwr_datasets/mock_utils_test.py +++ b/datasets/flwr_datasets/mock_utils_test.py @@ -19,7 +19,7 @@ import random import string from datetime import datetime, timedelta -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy as np from PIL import Image @@ -375,3 +375,62 @@ def _load_mocked_dataset( for params in zip(num_rows, split_names): dataset_dict[params[1]] = dataset_creation_fnc(params[0]) return datasets.DatasetDict(dataset_dict) + + +def _load_mocked_dataset_by_partial_download( + dataset_name: str, + split_name: str, + skip_take_list: List[Tuple[int, int]], + subset_name: Optional[str] = None, +) -> Dataset: + """Download a partial dataset. + + This functionality is not supported in the datasets library. This is an informal + way of achieving partial dataset download by using the `streaming=True` and creating + a dataset.Dataset from in-memory objects. + + Parameters + ---------- + dataset_name: str + Name of the dataset (passed to load_dataset). + split_name: str + Name of the split (passed to load_dataset) e.g. "train". + skip_take_list: List[Tuple[int, int]] + The streaming mode has a specific type of accessing the data, the first tuple + value is how many samples to skip, the second is how many samples to take. Due + to this mechanism, diverse samples can be taken (especially if the dataset is + sorted by the natual_id for NaturalIdPartitioner). + subset_name: Optional[str] + Name of the subset (passed to load_dataset) e.g. "v0.01" for speech_commands. + + Returns + ------- + dataset: Dataset + The dataset with the requested samples. + """ + dataset = datasets.load_dataset( + dataset_name, name=subset_name, split=split_name, streaming=True + ) + dataset_list = [] + # It's a list of dict such that each dict represent a single sample of the dataset + # The sample is exactly the same as if the full dataset was downloaded and indexed + for skip, take in skip_take_list: + # dataset.skip(n).take(m) in streaming mode is equivalent (in terms of return) + # to the fully downloaded dataset index: dataset[n+1: (n+1 + m)] + dataset_list.extend(list(dataset.skip(skip).take(take))) + return Dataset.from_list(dataset_list) + + +def _load_mocked_dataset_dict_by_partial_download( + dataset_name: str, + split_names: List[str], + skip_take_lists: List[List[Tuple[int, int]]], + subset_name: Optional[str] = None, +) -> DatasetDict: + """Like _load_mocked_dataset_by_partial_download but for many splits.""" + dataset_dict = {} + for split_name, skip_take_list in zip(split_names, skip_take_lists): + dataset_dict[split_name] = _load_mocked_dataset_by_partial_download( + dataset_name, split_name, skip_take_list, subset_name + ) + return DatasetDict(dataset_dict)