Skip to content

Commit

Permalink
Add distribution partitioner test
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Jul 15, 2024
1 parent 226b676 commit 2abeb21
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _check_distribution_array_shape_if_needed(self) -> None:

def _check_distribution_array_sum_if_needed(self) -> None:
"""Test correctness of distribution array sum."""
if not self._partition_id_to_indices_determined:
if not self._partition_id_to_indices_determined and not self._rescale:
labels = self.dataset[self._partition_by]
distribution = sorted(Counter(labels).items())
distribution_vals = [v for _, v in distribution]
Expand Down
189 changes: 189 additions & 0 deletions datasets/flwr_datasets/partitioner/distribution_partitioner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for DistributionPartitioner."""

from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import pytest
from flwr.common.typing import NDArrayFloat

from datasets import Dataset
from flwr_datasets.partitioner.distribution_partitioner import DistributionPartitioner


def _dummy_dataset_setup(
    num_samples: int, partition_by: str, num_unique_classes: int
) -> Dataset:
    """Build a synthetic dataset with cyclic labels and random features.

    Parameters
    ----------
    num_samples : int
        Number of rows in the resulting dataset.
    partition_by : str
        Name of the column that holds the class labels.
    num_unique_classes : int
        Number of distinct label values; labels cycle through
        ``0 .. num_unique_classes - 1``.

    Returns
    -------
    Dataset
        Dataset with the label column ``partition_by`` and a ``"features"``
        column drawn from ``np.random.randn``.
    """
    # One repetition more than the integer quotient guarantees that at least
    # `num_samples` labels exist before truncating.
    repetitions = num_samples // num_unique_classes + 1
    class_ids = np.arange(num_unique_classes)
    labels = np.tile(class_ids, repetitions)[:num_samples]
    features = np.random.randn(num_samples)
    return Dataset.from_dict({partition_by: labels, "features": features})


def _dummy_distribution_setup(
num_partitions: int,
num_unique_labels_per_partition: int,
num_unique_labels: int,
random_mode: bool = False,
) -> NDArrayFloat:
"""Create a dummy distribution for testing."""
num_columns = num_unique_labels_per_partition * num_partitions / num_unique_labels
if random_mode:
rng = np.random.default_rng(2024)
return rng.integers(1, 10, size=(num_unique_labels, int(num_columns)))
return np.tile(np.arange(num_columns) + 1.0, (num_unique_labels, 1))


# pylint: disable=R0913
def _get_partitioner(
    num_partitions: int,
    num_unique_labels_per_partition: int,
    num_samples: int,
    num_unique_labels: int,
    preassigned_num_samples_per_label: int,
    rescale_mode: bool = True,
) -> Tuple[DistributionPartitioner, Dict[int, Dataset]]:
    """Build a DistributionPartitioner on dummy data and load all partitions.

    Parameters
    ----------
    num_partitions : int
        Number of partitions to create and load.
    num_unique_labels_per_partition : int
        Number of distinct labels requested per partition.
    num_samples : int
        Number of rows in the dummy dataset.
    num_unique_labels : int
        Number of distinct labels in the dummy dataset.
    preassigned_num_samples_per_label : int
        Forwarded unchanged to the partitioner.
    rescale_mode : bool
        Forwarded to the partitioner as ``rescale``.

    Returns
    -------
    Tuple[DistributionPartitioner, Dict[int, Dataset]]
        The configured partitioner and a mapping of partition id to the
        partition returned by ``load_partition``.
    """
    dummy_dataset = _dummy_dataset_setup(
        num_samples=num_samples,
        partition_by="labels",
        num_unique_classes=num_unique_labels,
    )
    dummy_distribution = _dummy_distribution_setup(
        num_partitions=num_partitions,
        num_unique_labels_per_partition=num_unique_labels_per_partition,
        num_unique_labels=num_unique_labels,
    )

    partitioner = DistributionPartitioner(
        distribution_array=dummy_distribution,
        num_partitions=num_partitions,
        num_unique_labels_per_partition=num_unique_labels_per_partition,
        partition_by="labels",
        preassigned_num_samples_per_label=preassigned_num_samples_per_label,
        rescale=rescale_mode,
    )
    partitioner.dataset = dummy_dataset

    # Load every partition eagerly, in ascending partition-id order.
    loaded: Dict[int, Dataset] = {}
    for partition_id in range(num_partitions):
        loaded[partition_id] = partitioner.load_partition(partition_id)

    return partitioner, loaded


def _collect_label_counts(
    partitions: Dict[int, Dataset], labels: List[int]
) -> Dict[int, List[int]]:
    """Gather, for every label, its sample count in each partition holding it.

    Parameters
    ----------
    partitions : Dict[int, Dataset]
        Mapping of partition id to the loaded partition.
    labels : List[int]
        Labels for which an (initially empty) entry is created.

    Returns
    -------
    Dict[int, List[int]]
        For each label, the number of samples carrying that label in every
        partition that contains it, in partition iteration order.
    """
    counts_per_label: Dict[int, List[int]] = {label: [] for label in labels}
    for partition in partitions.values():
        # Count the labels once per partition instead of once per label.
        value_counts = Counter(partition["labels"])
        for label, count in value_counts.items():
            counts_per_label[label].append(count)
    return counts_per_label


@pytest.mark.parametrize(
    "num_partitions, num_unique_labels_per_partition, num_samples, "
    "num_unique_labels, preassigned_num_samples_per_label",
    [
        (10, 2, 200, 10, 5),
        (10, 2, 200, 10, 0),
        (20, 1, 200, 10, 5),
    ],
)
# pylint: disable=R0913
class TestDistributionPartitioner:
    """Unit tests for DistributionPartitioner."""

    def test_correct_num_classes_when_partitioned(
        self,
        num_partitions: int,
        num_unique_labels_per_partition: int,
        num_samples: int,
        num_unique_labels: int,
        preassigned_num_samples_per_label: int,
    ) -> None:
        """Test that each partition holds the requested number of unique classes."""
        _, partitions = _get_partitioner(
            num_partitions=num_partitions,
            num_unique_labels_per_partition=num_unique_labels_per_partition,
            num_samples=num_samples,
            num_unique_labels=num_unique_labels,
            preassigned_num_samples_per_label=preassigned_num_samples_per_label,
        )

        for partition in partitions.values():
            unique_classes = np.unique(partition["labels"])
            assert num_unique_labels_per_partition == len(unique_classes)

    def test_correct_num_times_classes_sampled_across_partitions(
        self,
        num_partitions: int,
        num_unique_labels_per_partition: int,
        num_samples: int,
        num_unique_labels: int,
        preassigned_num_samples_per_label: int,
    ) -> None:
        """Test correct number of times each unique class is drawn from distribution."""
        partitioner, partitions = _get_partitioner(
            num_partitions=num_partitions,
            num_unique_labels_per_partition=num_unique_labels_per_partition,
            num_samples=num_samples,
            num_unique_labels=num_unique_labels,
            preassigned_num_samples_per_label=preassigned_num_samples_per_label,
        )
        dataset_labels = partitioner.dataset.unique("labels")
        partitioned_distribution = _collect_label_counts(partitions, dataset_labels)

        # Every label must appear in exactly as many partitions as the
        # distribution array has columns.
        num_columns = (
            num_unique_labels_per_partition * num_partitions / num_unique_labels
        )
        for label in dataset_labels:
            assert num_columns == len(partitioned_distribution[label])

    def test_exact_distribution_assignment(
        self,
        num_partitions: int,
        num_unique_labels_per_partition: int,
        num_samples: int,
        num_unique_labels: int,
        preassigned_num_samples_per_label: int,
    ) -> None:
        """Test that exact distribution is allocated to each class."""
        partitioner, partitions = _get_partitioner(
            num_partitions=num_partitions,
            num_unique_labels_per_partition=num_unique_labels_per_partition,
            num_samples=num_samples,
            num_unique_labels=num_unique_labels,
            preassigned_num_samples_per_label=preassigned_num_samples_per_label,
            rescale_mode=False,
        )
        dataset_labels = partitioner.dataset.unique("labels")
        partitioned_distribution = _collect_label_counts(partitions, dataset_labels)

        # Row `idx` of the distribution array is compared against the counts
        # observed for the idx-th label in sorted order.
        for idx, label in enumerate(sorted(dataset_labels)):
            assert np.array_equal(
                partitioner._distribution_array[idx],  # pylint: disable=W0212
                partitioned_distribution[label],
            )

0 comments on commit 2abeb21

Please sign in to comment.