From 2abeb212a27a9a8dafdafbb6bac0399d949c722e Mon Sep 17 00:00:00 2001
From: Chong Shen Ng
Date: Mon, 15 Jul 2024 12:51:36 +0100
Subject: [PATCH] Add distribution partitioner test

---
 .../partitioner/distribution_partitioner.py   |   2 +-
 .../distribution_partitioner_test.py          | 189 ++++++++++++++++++
 2 files changed, 190 insertions(+), 1 deletion(-)
 create mode 100644 datasets/flwr_datasets/partitioner/distribution_partitioner_test.py

diff --git a/datasets/flwr_datasets/partitioner/distribution_partitioner.py b/datasets/flwr_datasets/partitioner/distribution_partitioner.py
index 7027c0afa125..7245923ecd75 100644
--- a/datasets/flwr_datasets/partitioner/distribution_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/distribution_partitioner.py
@@ -305,7 +305,7 @@ def _check_distribution_array_shape_if_needed(self) -> None:
 
     def _check_distribution_array_sum_if_needed(self) -> None:
         """Test correctness of distribution array sum."""
-        if not self._partition_id_to_indices_determined:
+        if not self._partition_id_to_indices_determined and not self._rescale:
             labels = self.dataset[self._partition_by]
             distribution = sorted(Counter(labels).items())
             distribution_vals = [v for _, v in distribution]
diff --git a/datasets/flwr_datasets/partitioner/distribution_partitioner_test.py b/datasets/flwr_datasets/partitioner/distribution_partitioner_test.py
new file mode 100644
index 000000000000..e3d222ebcfa8
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/distribution_partitioner_test.py
@@ -0,0 +1,189 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for DistributionPartitioner."""
+
+from collections import Counter
+from typing import Dict, Tuple
+
+import numpy as np
+import pytest
+from flwr.common.typing import NDArrayFloat
+
+from datasets import Dataset
+from flwr_datasets.partitioner.distribution_partitioner import DistributionPartitioner
+
+
+def _dummy_dataset_setup(
+    num_samples: int, partition_by: str, num_unique_classes: int
+) -> Dataset:
+    """Create a dummy dataset for testing."""
+    data = {
+        partition_by: np.tile(
+            np.arange(num_unique_classes), num_samples // num_unique_classes + 1
+        )[:num_samples],
+        "features": np.random.randn(num_samples),
+    }
+    return Dataset.from_dict(data)
+
+
+def _dummy_distribution_setup(
+    num_partitions: int,
+    num_unique_labels_per_partition: int,
+    num_unique_labels: int,
+    random_mode: bool = False,
+) -> NDArrayFloat:
+    """Create a dummy distribution for testing."""
+    num_columns = num_unique_labels_per_partition * num_partitions / num_unique_labels
+    if random_mode:
+        rng = np.random.default_rng(2024)
+        return rng.integers(1, 10, size=(num_unique_labels, int(num_columns)))
+    return np.tile(np.arange(num_columns) + 1.0, (num_unique_labels, 1))
+
+
+# pylint: disable=R0913
+def _get_partitioner(
+    num_partitions: int,
+    num_unique_labels_per_partition: int,
+    num_samples: int,
+    num_unique_labels: int,
+    preassigned_num_samples_per_label: int,
+    rescale_mode: bool = True,
+) -> Tuple[DistributionPartitioner, Dict[int, Dataset]]:
+    """Create a DistributionPartitioner and load every one of its partitions."""
+    dataset = _dummy_dataset_setup(
+        num_samples,
+        "labels",
+        num_unique_labels,
+    )
+    distribution = _dummy_distribution_setup(
+        num_partitions,
+        num_unique_labels_per_partition,
+        num_unique_labels,
+    )
+    partitioner = DistributionPartitioner(
+        distribution_array=distribution,
+        num_partitions=num_partitions,
+        num_unique_labels_per_partition=num_unique_labels_per_partition,
+        partition_by="labels",
+        preassigned_num_samples_per_label=preassigned_num_samples_per_label,
+        rescale=rescale_mode,
+    )
+    partitioner.dataset = dataset
+    partitions: Dict[int, Dataset] = {
+        pid: partitioner.load_partition(pid) for pid in range(num_partitions)
+    }
+
+    return partitioner, partitions
+
+
+@pytest.mark.parametrize(
+    "num_partitions, num_unique_labels_per_partition, num_samples, "
+    "num_unique_labels, preassigned_num_samples_per_label",
+    [
+        (10, 2, 200, 10, 5),
+        (10, 2, 200, 10, 0),
+        (20, 1, 200, 10, 5),
+    ],
+)
+# pylint: disable=R0913
+class TestDistributionPartitioner:
+    """Unit tests for DistributionPartitioner."""
+
+    def test_correct_num_classes_when_partitioned(
+        self,
+        num_partitions: int,
+        num_unique_labels_per_partition: int,
+        num_samples: int,
+        num_unique_labels: int,
+        preassigned_num_samples_per_label: int,
+    ) -> None:
+        """Test correct number of unique classes."""
+        _, partitions = _get_partitioner(
+            num_partitions=num_partitions,
+            num_unique_labels_per_partition=num_unique_labels_per_partition,
+            num_samples=num_samples,
+            num_unique_labels=num_unique_labels,
+            preassigned_num_samples_per_label=preassigned_num_samples_per_label,
+        )
+        unique_classes_per_partition = {
+            pid: np.unique(partition["labels"]) for pid, partition in partitions.items()
+        }
+
+        for unique_classes in unique_classes_per_partition.values():
+            assert num_unique_labels_per_partition == len(unique_classes)
+
+    def test_correct_num_times_classes_sampled_across_partitions(
+        self,
+        num_partitions: int,
+        num_unique_labels_per_partition: int,
+        num_samples: int,
+        num_unique_labels: int,
+        preassigned_num_samples_per_label: int,
+    ) -> None:
+        """Test correct number of times each unique class is drawn from distribution."""
+        partitioner, partitions = _get_partitioner(
+            num_partitions=num_partitions,
+            num_unique_labels_per_partition=num_unique_labels_per_partition,
+            num_samples=num_samples,
+            num_unique_labels=num_unique_labels,
+            preassigned_num_samples_per_label=preassigned_num_samples_per_label,
+        )
+
+        partitioned_distribution = {
+            label: [] for label in partitioner.dataset.unique("labels")
+        }
+
+        num_columns = (
+            num_unique_labels_per_partition * num_partitions / num_unique_labels
+        )
+        for partition in partitions.values():
+            value_counts = Counter(partition["labels"])  # count once per partition
+            for label in partition.unique("labels"):
+                partitioned_distribution[label].append(value_counts[label])
+
+        for label in partitioner.dataset.unique("labels"):
+            assert num_columns == len(partitioned_distribution[label])
+
+    def test_exact_distribution_assignment(
+        self,
+        num_partitions: int,
+        num_unique_labels_per_partition: int,
+        num_samples: int,
+        num_unique_labels: int,
+        preassigned_num_samples_per_label: int,
+    ) -> None:
+        """Test that exact distribution is allocated to each class."""
+        partitioner, partitions = _get_partitioner(
+            num_partitions=num_partitions,
+            num_unique_labels_per_partition=num_unique_labels_per_partition,
+            num_samples=num_samples,
+            num_unique_labels=num_unique_labels,
+            preassigned_num_samples_per_label=preassigned_num_samples_per_label,
+            rescale_mode=False,
+        )
+        partitioned_distribution = {
+            label: [] for label in partitioner.dataset.unique("labels")
+        }
+
+        for partition in partitions.values():
+            value_counts = Counter(partition["labels"])  # count once per partition
+            for label in partition.unique("labels"):
+                partitioned_distribution[label].append(value_counts[label])
+
+        for idx, label in enumerate(sorted(partitioner.dataset.unique("labels"))):
+            assert np.array_equal(
+                partitioner._distribution_array[idx],  # pylint: disable=W0212
+                partitioned_distribution[label],
+            )