Skip to content

Commit

Permalink
Use parameterized_class
Browse files: browse the repository at this point in the history
  • Loading branch information
chongshenng committed Jul 22, 2024
1 parent 496e96e commit 0a90076
Showing 1 changed file with 39 additions and 70 deletions.
109 changes: 39 additions & 70 deletions datasets/flwr_datasets/partitioner/distribution_partitioner_test.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from parameterized import parameterized
from parameterized import parameterized_class

from datasets import Dataset
from flwr_datasets.common.typing import NDArrayFloat, NDArrayInt
Expand Down Expand Up @@ -89,73 +89,58 @@ def _get_partitioner(
return partitioner, partitions


@parameterized_class(
(
"num_partitions",
"num_unique_labels_per_partition",
"num_samples",
"num_unique_labels",
"preassigned_num_samples_per_label",
),
[
(10, 2, 200, 10, 5),
(10, 2, 200, 10, 0),
(20, 1, 200, 10, 5),
],
)
# pylint: disable=E1101
class TestDistributionPartitioner(unittest.TestCase):
"""Unit tests for DistributionPartitioner."""

@parameterized.expand( # type: ignore
[
# num_partitions, num_unique_labels_per_partition, num_samples,
# num_unique_labels, preassigned_num_samples_per_label
(10, 2, 200, 10, 5),
(10, 2, 200, 10, 0),
(20, 1, 200, 10, 5),
],
)
def test_correct_num_classes_when_partitioned(
self,
num_partitions: int,
num_unique_labels_per_partition: int,
num_samples: int,
num_unique_labels: int,
preassigned_num_samples_per_label: int,
) -> None:
def test_correct_num_classes_when_partitioned(self) -> None:
"""Test correct number of unique classes."""
_, partitions = _get_partitioner(
num_partitions=num_partitions,
num_unique_labels_per_partition=num_unique_labels_per_partition,
num_samples=num_samples,
num_unique_labels=num_unique_labels,
preassigned_num_samples_per_label=preassigned_num_samples_per_label,
num_partitions=self.num_partitions,
num_unique_labels_per_partition=self.num_unique_labels_per_partition,
num_samples=self.num_samples,
num_unique_labels=self.num_unique_labels,
preassigned_num_samples_per_label=self.preassigned_num_samples_per_label,
)
unique_classes_per_partition = {
pid: np.unique(partition["labels"]) for pid, partition in partitions.items()
}

for unique_classes in unique_classes_per_partition.values():
self.assertEqual(num_unique_labels_per_partition, len(unique_classes))

@parameterized.expand( # type: ignore
[
# num_partitions, num_unique_labels_per_partition, num_samples,
# num_unique_labels, preassigned_num_samples_per_label
(10, 2, 200, 10, 5),
(10, 2, 200, 10, 0),
(20, 1, 200, 10, 5),
],
)
def test_correct_num_times_classes_sampled_across_partitions(
self,
num_partitions: int,
num_unique_labels_per_partition: int,
num_samples: int,
num_unique_labels: int,
preassigned_num_samples_per_label: int,
) -> None:
self.assertEqual(self.num_unique_labels_per_partition, len(unique_classes))

def test_correct_num_times_classes_sampled_across_partitions(self) -> None:
"""Test correct number of times each unique class is drawn from distribution."""
partitioner, partitions = _get_partitioner(
num_partitions=num_partitions,
num_unique_labels_per_partition=num_unique_labels_per_partition,
num_samples=num_samples,
num_unique_labels=num_unique_labels,
preassigned_num_samples_per_label=preassigned_num_samples_per_label,
num_partitions=self.num_partitions,
num_unique_labels_per_partition=self.num_unique_labels_per_partition,
num_samples=self.num_samples,
num_unique_labels=self.num_unique_labels,
preassigned_num_samples_per_label=self.preassigned_num_samples_per_label,
)

partitioned_distribution: Dict[Any, List[Any]] = {
label: [] for label in partitioner.dataset.unique("labels")
}

num_columns = (
num_unique_labels_per_partition * num_partitions / num_unique_labels
self.num_unique_labels_per_partition
* self.num_partitions
/ self.num_unique_labels
)
for _, partition in partitions.items():
for label in partition.unique("labels"):
Expand All @@ -165,30 +150,14 @@ def test_correct_num_times_classes_sampled_across_partitions(
for label in partitioner.dataset.unique("labels"):
self.assertEqual(num_columns, len(partitioned_distribution[label]))

@parameterized.expand( # type: ignore
[
# num_partitions, num_unique_labels_per_partition, num_samples,
# num_unique_labels, preassigned_num_samples_per_label
(10, 2, 200, 10, 5),
(10, 2, 200, 10, 0),
(20, 1, 200, 10, 5),
],
)
def test_exact_distribution_assignment(
self,
num_partitions: int,
num_unique_labels_per_partition: int,
num_samples: int,
num_unique_labels: int,
preassigned_num_samples_per_label: int,
) -> None:
def test_exact_distribution_assignment(self) -> None:
"""Test that exact distribution is allocated to each class."""
partitioner, partitions = _get_partitioner(
num_partitions=num_partitions,
num_unique_labels_per_partition=num_unique_labels_per_partition,
num_samples=num_samples,
num_unique_labels=num_unique_labels,
preassigned_num_samples_per_label=preassigned_num_samples_per_label,
num_partitions=self.num_partitions,
num_unique_labels_per_partition=self.num_unique_labels_per_partition,
num_samples=self.num_samples,
num_unique_labels=self.num_unique_labels,
preassigned_num_samples_per_label=self.preassigned_num_samples_per_label,
rescale_mode=False,
)
partitioned_distribution: Dict[Any, List[Any]] = {
Expand Down

0 comments on commit 0a90076

Please sign in to comment.