Skip to content

Commit

Permalink
Add function to perform partial download of dataset for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jul 20, 2024
1 parent 3ead850 commit 915a6f3
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion datasets/flwr_datasets/mock_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import random
import string
from datetime import datetime, timedelta
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -375,3 +375,47 @@ def _load_mocked_dataset(
for params in zip(num_rows, split_names):
dataset_dict[params[1]] = dataset_creation_fnc(params[0])
return datasets.DatasetDict(dataset_dict)


def _download_partial_dataset(
    dataset_name: str,
    split_name: str,
    skip_take_list: List[Tuple[int, int]],
    subset_name: Optional[str] = None,
) -> Dataset:
    """Download a partial dataset.

    This functionality is not supported in the datasets library. This is an
    informal way of achieving it: the split is opened with `streaming=True`,
    the requested samples are pulled into memory, and a `datasets.Dataset` is
    created from those in-memory objects.

    Parameters
    ----------
    dataset_name: str
        Name of the dataset (passed to load_dataset).
    split_name: str
        Name of the split (passed to load_dataset) e.g. "train".
    skip_take_list: List[Tuple[int, int]]
        Windows of samples to fetch. In each tuple the first value is how many
        samples to skip and the second is how many samples to take. Every window
        is applied to the original stream, so the skip value is always counted
        from the first sample of the split (not from the end of the previous
        window); overlapping windows therefore produce repeated samples. Due to
        this mechanism, diverse samples can be taken (especially if the dataset
        is sorted by the natural_id for NaturalIdPartitioner).
    subset_name: Optional[str]
        Name of the subset (passed to load_dataset) e.g. "v0.01" for
        speech_commands.

    Returns
    -------
    dataset: Dataset
        The dataset with the requested samples, in the order of the windows in
        `skip_take_list`.
    """
    dataset = datasets.load_dataset(
        dataset_name, name=subset_name, split=split_name, streaming=True
    )
    # A list of dicts, each dict representing a single sample of the dataset.
    # The sample is exactly the same as if the full dataset was downloaded and
    # indexed.
    dataset_list = []
    for skip, take in skip_take_list:
        # In streaming mode dataset.skip(n).take(m) yields the same samples as
        # indexing the fully downloaded dataset with dataset[n : n + m].
        # `extend` consumes the iterable directly, so no intermediate list is
        # needed.
        dataset_list.extend(dataset.skip(skip).take(take))
    return Dataset.from_list(dataset_list)

0 comments on commit 915a6f3

Please sign in to comment.