Skip to content

Commit

Permalink
Fix formatting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Feb 16, 2024
1 parent 5765ea2 commit 834708d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
6 changes: 3 additions & 3 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from unittest.mock import Mock, patch

import pytest
from mock_utils import _load_mocked_dataset
from parameterized import parameterized, parameterized_class

import datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.mock_utils import _load_mocked_dataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]
Expand Down Expand Up @@ -66,7 +66,7 @@ class BaseFederatedDatasetsTest(unittest.TestCase):
test_split = ""
subset = ""

def setUp(self):
def setUp(self) -> None:
"""Mock the dataset download prior to each method if needed.
If the `dataset_name` is in the `mocked_datasets` list, then the dataset
Expand All @@ -79,7 +79,7 @@ def setUp(self):
self.dataset_name, [200, 100], ["train", self.test_split], self.subset
)

def tearDown(self):
def tearDown(self) -> None:
"""Clean up after the dataset mocking."""
if self.dataset_name in mocked_datasets:
patch.stopall()
Expand Down
36 changes: 18 additions & 18 deletions datasets/flwr_datasets/mock_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@
import random
import string
from datetime import datetime, timedelta
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import numpy as np
from PIL import Image

import datasets
from datasets import ClassLabel, DatasetDict, Features, Value
from datasets import ClassLabel, Dataset, DatasetDict, Features, Value


def _generate_artificial_strings(
num_rows: int, num_unique: int, string_length: int
) -> List[str]:
"""Create list of strings."""
unique_strings = set()
unique_strings: Set[str] = set()
while len(unique_strings) < num_unique:
random_str = "".join(
random.choices(string.ascii_letters + string.digits, k=string_length)
Expand Down Expand Up @@ -74,7 +74,7 @@ def _generate_random_sentence(
) -> str:
# Generate a random sentence with words of random lengths
sentence_length = random.randint(min_sentence_length, max_sentence_length)
sentence = []
sentence: List[str] = []
while len(" ".join(sentence)) < sentence_length:
word_length = random.randint(min_word_length, max_word_length)
word = _generate_random_word(word_length)
Expand All @@ -99,7 +99,7 @@ def _generate_random_sentences(
return text_col


def _make_num_rows_none(column, num_none):
def _make_num_rows_none(column: List[Any], num_none: int) -> List[Any]:
none_positions = random.sample(range(len(column)), num_none)
for pos in none_positions:
column[pos] = None
Expand All @@ -117,11 +117,9 @@ def _generate_random_date(
random_seconds = random.randint(0, int(time_between_dates.total_seconds()))
random_date = start_date + timedelta(seconds=random_seconds)

# Return the date in the specified format
if as_string:
return random_date.strftime(date_format)
else:
return random_date
return random_date


def _generate_random_date_column(
Expand All @@ -138,19 +136,19 @@ def _generate_random_date_column(
]


def _generate_random_int_column(num_rows: int, min_int: int, max_int: int):
def _generate_random_int_column(num_rows: int, min_int: int, max_int: int) -> List[int]:
return [random.randint(min_int, max_int) for _ in range(num_rows)]


def _generate_random_bool_column(num_rows: int):
def _generate_random_bool_column(num_rows: int) -> List[bool]:
return [random.choice([True, False]) for _ in range(num_rows)]


def _generate_random_image_column(
num_rows: int,
image_size: Union[Tuple[int, int], Tuple[int, int, int]],
simulate_type: str,
):
) -> List[Any]:
"""Simulate the images with the format that is found in HF Hub.
Directly using `Image.fromarray` does not work because it creates `PIL.Image.Image`.
Expand Down Expand Up @@ -181,7 +179,7 @@ def generate_random_audio_column(
num_rows: int,
sampling_rate: int,
length_in_samples: int,
):
) -> List[Dict[str, Any]]:
"""Simulate the audio column.
Audio column in the dataset is composed of an array of floats, sample_rate and a
Expand All @@ -197,7 +195,7 @@ def generate_random_audio_column(
return audios


def _mock_sentiment140(num_rows: int):
def _mock_sentiment140(num_rows: int) -> Dataset:
users = _generate_artificial_strings(
num_rows=num_rows, num_unique=30, string_length=5
)
Expand Down Expand Up @@ -249,7 +247,7 @@ def _mock_sentiment140(num_rows: int):
return dataset


def _mock_cifar100(num_rows: int):
def _mock_cifar100(num_rows: int) -> Dataset:
imgs = _generate_random_image_column(num_rows, (32, 32, 3), "PNG")
unique_fine_labels = [
"apple",
Expand Down Expand Up @@ -392,7 +390,7 @@ def _mock_cifar100(num_rows: int):
return dataset


def _mock_svhn_cropped_digits(num_rows: int):
def _mock_svhn_cropped_digits(num_rows: int) -> Dataset:
imgs = _generate_random_image_column(num_rows, (32, 32, 3), "PNG")
unique_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
label = _generate_artificial_categories(num_rows, unique_labels)
Expand All @@ -408,7 +406,7 @@ def _mock_svhn_cropped_digits(num_rows: int):
return dataset


def _mock_speach_commands(num_rows: int):
def _mock_speach_commands(num_rows: int) -> Dataset:
sampling_rate = 16_000
length_in_samples = 16_000
imgs = generate_random_audio_column(
Expand Down Expand Up @@ -447,7 +445,9 @@ def _mock_speach_commands(num_rows: int):
return dataset


def _mock_dict_dataset(num_rows: List[int], split_names: List[str], function: Callable):
def _mock_dict_dataset(
num_rows: List[int], split_names: List[str], function: Callable[[int], Dataset]
) -> DatasetDict:
dataset_dict = {}
for params in zip(num_rows, split_names):
dataset_dict[params[1]] = function(params[0])
Expand All @@ -465,7 +465,7 @@ def _load_mocked_dataset(
dataset_name: str,
num_rows: List[int],
split_names: List[str],
subset: Optional[str] = "",
subset: str = "",
) -> DatasetDict:
dataset_dict = {}
name = dataset_name if subset == "" else dataset_name + "_" + subset
Expand Down

0 comments on commit 834708d

Please sign in to comment.