Skip to content

Commit

Permalink
Check for unique_label correctness
Browse files (browse the repository at this point in the history)
  • Loading branch information
adam-narozniak committed May 20, 2024
1 parent 1e8da7c commit 8120e16
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def compute_counts(
label_counts: pd.Series
The pd.Series with label as indices and counts as values.
"""
if len(unique_labels) != len(set(unique_labels)):
raise ValueError("unique_labels must contain unique elements only.")
labels_series = pd.Series(labels)
label_counts = labels_series.value_counts()
label_counts_with_zeros = pd.Series(index=unique_labels, data=0)
Expand Down
14 changes: 14 additions & 0 deletions datasets/flwr_datasets/metrics/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def test_distribution_sum_to_one(self, labels, unique_labels) -> None:
result = compute_distribution(labels, unique_labels)
self.assertAlmostEqual(result.sum(), 1.0)

def test_compute_counts_non_unique_labels(self) -> None:
    """Test that duplicated unique_labels make compute_counts raise ValueError."""
    observed = [1, 2, 3]
    candidates_with_duplicate = [1, 2, 2, 3]  # the value 2 appears twice
    # Callable form of assertRaises: fails the test unless ValueError is raised.
    self.assertRaises(
        ValueError, compute_counts, observed, candidates_with_duplicate
    )

def test_compute_distribution_non_unique_labels(self) -> None:
    """Test that duplicated unique_labels make compute_distribution raise."""
    observed = [1, 1, 2, 3]
    candidates_with_duplicate = [1, 1, 2, 3]  # the value 1 appears twice
    # Callable form of assertRaises: fails the test unless ValueError is raised.
    self.assertRaises(
        ValueError, compute_distribution, observed, candidates_with_duplicate
    )


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 8120e16

Please sign in to comment.