Skip to content

Commit

Permalink
Rename arg 'shuffle_seed back to rng_seed`
Browse files Browse the repository at this point in the history
  • Loading branch information
KarhouTam committed Sep 10, 2024
1 parent 5a25e5e commit ebc2f3a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
34 changes: 27 additions & 7 deletions datasets/flwr_datasets/partitioner/image_semantic_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ class ImageSemanticPartitioner(Partitioner):
shuffle: bool
Whether to randomize the order of samples. Shuffling applied after the
samples assignment to partitions.
shuffle_seed: Optional[int]
Seed used for shuffling. Defaults to 42.
rng_seed: Optional[int]
Seed used for numpy random number generator,
which used throughout the process. Defaults to 42.
pca_seed: Optional[int]
Seed used for PCA dimensionality reduction. Defaults to 42.
gmm_seed: Optional[int]
Expand Down Expand Up @@ -137,7 +138,7 @@ def __init__( # pylint: disable=R0913
use_cuda: bool = False,
image_column_name: Optional[str] = None,
shuffle: bool = True,
shuffle_seed: Optional[int] = 42,
rng_seed: Optional[int] = 42,
pca_seed: Optional[int] = 42,
gmm_seed: Optional[int] = 42,
) -> None:
Expand All @@ -154,13 +155,13 @@ def __init__( # pylint: disable=R0913
self._use_cuda = use_cuda
self._image_column_name = image_column_name
self._shuffle = shuffle
self._shuffle_seed = shuffle_seed
self._rng_seed = rng_seed
self._pca_seed = pca_seed
self._gmm_seed = gmm_seed

self._check_variable_validation()

self._rng_numpy = np.random.default_rng(seed=self._shuffle_seed)
self._rng_numpy = np.random.default_rng(seed=self._rng_seed)

# The attributes below are determined during the first call to load_partition
self._unique_classes: Optional[Union[List[int], List[str]]] = None
Expand Down Expand Up @@ -497,9 +498,28 @@ def _check_variable_validation(self) -> None:
raise ValueError("The gmm max iter needs to be greater than zero.")
if self._pca_components <= 0:
raise ValueError("The pca components needs to be greater than zero.")
if not isinstance(self._shuffle_seed, int):
raise TypeError("The shuffle seed needs to be an integer.")
if not isinstance(self._rng_seed, int):
raise TypeError("The rng seed needs to be an integer.")
if not isinstance(self._pca_seed, int):
raise TypeError("The pca seed needs to be an integer.")
if not isinstance(self._gmm_seed, int):
raise TypeError("The gmm seed needs to be an integer.")

if __name__ == "__main__":
# ===================== Test with custom Dataset =====================
from datasets import Dataset

dataset = {
"image": [np.random.randn(28, 28) for _ in range(50)],
"label": [i % 3 for i in range(50)],
}
dataset = Dataset.from_dict(dataset)
partitioner = ImageSemanticPartitioner(
num_partitions=5, partition_by="label", pca_components=30
)
partitioner.dataset = dataset
partition = partitioner.load_partition(0)
partition_sizes = partition_sizes = [
len(partitioner.load_partition(partition_id)) for partition_id in range(5)
]
print(sorted(partition_sizes))
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Test ImageSemanticPartitioner."""


# pylint: disable=W0212
import string
import unittest
Expand All @@ -33,14 +32,14 @@
def _dummy_setup(
data_shape: Tuple[int, ...] = (28, 28, 1),
num_partitions: int = 3,
num_rows: int = 10,
num_rows: int = 50,
partition_by: str = "label",
efficient_net_type: int = 0,
batch_size: int = 32,
pca_components: int = 6,
gmm_max_iter: int = 2,
gmm_init_params: str = "random",
shuffle_seed: Optional[int] = 42,
rng_seed: Optional[int] = 42,
pca_seed: Optional[int] = 42,
gmm_seed: Optional[int] = 42,
) -> Tuple[Dataset, ImageSemanticPartitioner]:
Expand All @@ -58,7 +57,7 @@ def _dummy_setup(
pca_components=pca_components,
gmm_max_iter=gmm_max_iter,
gmm_init_params=gmm_init_params,
shuffle_seed=shuffle_seed,
rng_seed=rng_seed,
pca_seed=pca_seed,
gmm_seed=gmm_seed,
)
Expand Down Expand Up @@ -88,7 +87,7 @@ def test_valid_initialization(
pca_components: int,
gmm_max_iter: int,
gmm_init_params: str,
shuffle_seed: Optional[int] = 42,
rng_seed: Optional[int] = 42,
pca_seed: Optional[int] = 42,
gmm_seed: Optional[int] = 42,
) -> None:
Expand All @@ -103,7 +102,7 @@ def test_valid_initialization(
pca_components=pca_components,
gmm_max_iter=gmm_max_iter,
gmm_init_params=gmm_init_params,
shuffle_seed=shuffle_seed,
rng_seed=rng_seed,
pca_seed=pca_seed,
gmm_seed=gmm_seed,
)
Expand All @@ -115,7 +114,7 @@ def test_valid_initialization(
partitioner._pca_components,
partitioner._gmm_max_iter,
partitioner._gmm_init_params,
partitioner._shuffle_seed,
partitioner._rng_seed,
partitioner._pca_seed,
partitioner._gmm_seed,
),
Expand All @@ -126,7 +125,7 @@ def test_valid_initialization(
pca_components,
gmm_max_iter,
gmm_init_params,
shuffle_seed,
rng_seed,
pca_seed,
gmm_seed,
),
Expand Down Expand Up @@ -167,21 +166,21 @@ def test_gaussian_mixture_model(

@parameterized.expand([(1, 2, 3)]) # type: ignore
def test_seeds(
self, shuffle_seed: int, pca_seed: Optional[int], gmm_seed: Optional[int]
self, rng_seed: int, pca_seed: Optional[int], gmm_seed: Optional[int]
) -> None:
"""Test if seeds are correct."""
_, partitioner = _dummy_setup(
shuffle_seed=shuffle_seed, pca_seed=pca_seed, gmm_seed=gmm_seed
rng_seed=rng_seed, pca_seed=pca_seed, gmm_seed=gmm_seed
)
self.assertEqual(
(partitioner._shuffle_seed, partitioner._pca_seed, partitioner._gmm_seed),
(shuffle_seed, pca_seed, gmm_seed),
(partitioner._rng_seed, partitioner._pca_seed, partitioner._gmm_seed),
(rng_seed, pca_seed, gmm_seed),
)

def test_determine_partition_id_to_indices(self) -> None:
"""Test the determine_nod_id_to_indices matches the flag after the call."""
_, partitioner = _dummy_setup()
partitioner._determine_partition_id_to_indices_if_needed()
partitioner.load_partition(0)
self.assertTrue(
partitioner._partition_id_to_indices_determined
and len(partitioner._partition_id_to_indices) == 3
Expand Down Expand Up @@ -238,9 +237,11 @@ def test_invalid_num_partitions(self, num_partitions: int) -> None:
_, partitioner = _dummy_setup(num_partitions=num_partitions, num_rows=10)
partitioner.load_partition(0)

@parameterized.expand([(0,), (-1,), (2.0,), (11,)]) # type: ignore
@parameterized.expand([(0,), (-1,), (2.0,), (110000000,)]) # type: ignore
def test_invalid_pca_components(self, pca_components: int) -> None:
"""Test if pca_components is not a positive integer."""
"""Test if pca_components is not a positive integer or
larger than the number of rows.
"""
with self.assertRaises((ValueError, TypeError)):
_, partitioner = _dummy_setup(pca_components=pca_components)
partitioner.load_partition(0)
Expand All @@ -267,12 +268,12 @@ def test_invalid_gaussian_mixture_config(

@parameterized.expand([("1", 2, 3), (1.2, 2, 3), (1, 2, 3.0)]) # type: ignore
def test_invalid_seeds(
self, shuffle_seed: int, pca_seed: Optional[int], gmm_seed: Optional[int]
self, rng_seed: int, pca_seed: Optional[int], gmm_seed: Optional[int]
) -> None:
"""Test if raises when the seeds are not integers."""
with self.assertRaises((TypeError, ValueError)):
_, partitioner = _dummy_setup(
shuffle_seed=shuffle_seed, pca_seed=pca_seed, gmm_seed=gmm_seed
rng_seed=rng_seed, pca_seed=pca_seed, gmm_seed=gmm_seed
)
partitioner.load_partition(0)

Expand Down

0 comments on commit ebc2f3a

Please sign in to comment.