Skip to content

Commit

Permalink
feat(datasets) Add function to perform partial download of dataset fo…
Browse files Browse the repository at this point in the history
…r tests (#3860)
  • Loading branch information
adam-narozniak authored Jul 22, 2024
1 parent df1084f commit b7ae03d
Showing 1 changed file with 60 additions and 1 deletion.
61 changes: 60 additions & 1 deletion datasets/flwr_datasets/mock_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import random
import string
from datetime import datetime, timedelta
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -375,3 +375,62 @@ def _load_mocked_dataset(
for params in zip(num_rows, split_names):
dataset_dict[params[1]] = dataset_creation_fnc(params[0])
return datasets.DatasetDict(dataset_dict)


def _load_mocked_dataset_by_partial_download(
dataset_name: str,
split_name: str,
skip_take_list: List[Tuple[int, int]],
subset_name: Optional[str] = None,
) -> Dataset:
"""Download a partial dataset.
This functionality is not supported in the datasets library. This is an informal
way of achieving partial dataset download by using the `streaming=True` and creating
a dataset.Dataset from in-memory objects.
Parameters
----------
dataset_name: str
Name of the dataset (passed to load_dataset).
split_name: str
Name of the split (passed to load_dataset) e.g. "train".
skip_take_list: List[Tuple[int, int]]
The streaming mode has a specific type of accessing the data, the first tuple
value is how many samples to skip, the second is how many samples to take. Due
to this mechanism, diverse samples can be taken (especially if the dataset is
sorted by the natual_id for NaturalIdPartitioner).
subset_name: Optional[str]
Name of the subset (passed to load_dataset) e.g. "v0.01" for speech_commands.
Returns
-------
dataset: Dataset
The dataset with the requested samples.
"""
dataset = datasets.load_dataset(
dataset_name, name=subset_name, split=split_name, streaming=True
)
dataset_list = []
# It's a list of dict such that each dict represent a single sample of the dataset
# The sample is exactly the same as if the full dataset was downloaded and indexed
for skip, take in skip_take_list:
# dataset.skip(n).take(m) in streaming mode is equivalent (in terms of return)
# to the fully downloaded dataset index: dataset[n+1: (n+1 + m)]
dataset_list.extend(list(dataset.skip(skip).take(take)))
return Dataset.from_list(dataset_list)


def _load_mocked_dataset_dict_by_partial_download(
dataset_name: str,
split_names: List[str],
skip_take_lists: List[List[Tuple[int, int]]],
subset_name: Optional[str] = None,
) -> DatasetDict:
"""Like _load_mocked_dataset_by_partial_download but for many splits."""
dataset_dict = {}
for split_name, skip_take_list in zip(split_names, skip_take_lists):
dataset_dict[split_name] = _load_mocked_dataset_by_partial_download(
dataset_name, split_name, skip_take_list, subset_name
)
return DatasetDict(dataset_dict)

0 comments on commit b7ae03d

Please sign in to comment.