# Patch summary (commentary only; placed before the first `diff --git` header).
#
# Purpose: make compute_counts() reject a `unique_labels` argument that
# contains duplicates.
#
# datasets/flwr_datasets/metrics/utils.py  (hunk -40,6 +40,8)
#   * Adds a guard right after the compute_counts docstring and before
#     `labels_series = pd.Series(labels)`: when
#     len(unique_labels) != len(set(unique_labels)) it raises
#     ValueError("unique_labels must contain unique elements only.").
#   * The guard builds set(unique_labels), so the labels must be hashable.
#
# datasets/flwr_datasets/metrics/utils_test.py  (hunk -69,6 +69,20)
#   * test_compute_counts_non_unique_labels: labels=[1, 2, 3],
#     unique_labels=[1, 2, 2, 3]; expects compute_counts to raise ValueError.
#   * test_compute_distribution_non_unique_labels: labels=[1, 1, 2, 3],
#     unique_labels=[1, 1, 2, 3]; expects compute_distribution to raise
#     ValueError.
#   * NOTE(review): compute_distribution's body is not part of this patch;
#     presumably it delegates to compute_counts and inherits the guard --
#     confirm against utils.py.
#
# NOTE(review): the diff below is stored on a single line (original line
# breaks and indentation were lost); restore them from the hunk headers
# before attempting to apply it.
diff --git a/datasets/flwr_datasets/metrics/utils.py b/datasets/flwr_datasets/metrics/utils.py index 4e8d69fcb4bb..ef4ef449eac5 100644 --- a/datasets/flwr_datasets/metrics/utils.py +++ b/datasets/flwr_datasets/metrics/utils.py @@ -40,6 +40,8 @@ def compute_counts( label_counts: pd.Series The pd.Series with label as indices and counts as values. """ + if len(unique_labels) != len(set(unique_labels)): + raise ValueError("unique_labels must contain unique elements only.") labels_series = pd.Series(labels) label_counts = labels_series.value_counts() label_counts_with_zeros = pd.Series(index=unique_labels, data=0) diff --git a/datasets/flwr_datasets/metrics/utils_test.py b/datasets/flwr_datasets/metrics/utils_test.py index 5e4d9380bc6b..825cbb397508 100644 --- a/datasets/flwr_datasets/metrics/utils_test.py +++ b/datasets/flwr_datasets/metrics/utils_test.py @@ -69,6 +69,20 @@ def test_distribution_sum_to_one(self, labels, unique_labels) -> None: result = compute_distribution(labels, unique_labels) self.assertAlmostEqual(result.sum(), 1.0) + def test_compute_counts_non_unique_labels(self) -> None: + """Test if not having the unique labels raises ValueError.""" + labels = [1, 2, 3] + unique_labels = [1, 2, 2, 3] + with self.assertRaises(ValueError): + compute_counts(labels, unique_labels) + + def test_compute_distribution_non_unique_labels(self) -> None: + """Test if not having the unique labels raises ValueError.""" + labels = [1, 1, 2, 3] + unique_labels = [1, 1, 2, 3] + with self.assertRaises(ValueError): + compute_distribution(labels, unique_labels) + if __name__ == "__main__": unittest.main()