Skip to content

Commit

Permalink
Add tests for compute_counts and frequencies fnc
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jun 5, 2024
1 parent 80131b0 commit 9eaefdc
Showing 1 changed file with 123 additions and 9 deletions.
132 changes: 123 additions & 9 deletions datasets/flwr_datasets/metrics/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,126 @@
import unittest

import pandas as pd
from parameterized import parameterized
from parameterized import parameterized, parameterized_class

from flwr_datasets.metrics.utils import compute_counts, compute_frequency
import datasets
from datasets import ClassLabel
from flwr_datasets.metrics.utils import (
_compute_counts,
_compute_frequencies,
compute_counts,
compute_frequencies,
)
from flwr_datasets.partitioner import IidPartitioner


class TestMetricsUtils(unittest.TestCase):
@parameterized_class(
("dataset", "result"),
[
(
datasets.Dataset.from_dict({"feature": list(range(10)), "label": [0] * 10}),
pd.DataFrame([[5], [5]], index=pd.Index([0, 1], name="Partition ID")),
),
(
datasets.Dataset.from_dict(
{"feature": list(range(10)), "label": [0] * 5 + [1] * 5}
),
pd.DataFrame([[5, 0], [0, 5]], index=pd.Index([0, 1], name="Partition ID")),
),
(
datasets.Dataset.from_dict(
{"feature": list(range(10)), "label": [0, 0, 0, 1, 1] + [1, 1, 1, 1, 2]}
),
pd.DataFrame(
[[3, 2, 0], [0, 4, 1]], index=pd.Index([0, 1], name="Partition ID")
),
),
],
)
class TestPublicMetricsUtils(unittest.TestCase):
"""Test metrics utils."""

dataset: datasets.Dataset
result: pd.DataFrame

def test_compute_counts(self) -> None:
"""Test if the counts are computed correctly."""
iid_partitioner = IidPartitioner(num_partitions=2)
iid_partitioner.dataset = self.dataset
count = compute_counts(iid_partitioner, column_name="label")
pd.testing.assert_frame_equal(count, self.result)

def test_compute_frequencies(self) -> None:
"""Test if the frequencies are computed correctly."""
print(self.dataset.features)
iid_partitioner = IidPartitioner(num_partitions=2)
iid_partitioner.dataset = self.dataset
frequencies = compute_frequencies(iid_partitioner, column_name="label")
result = self.result.div(self.result.sum(axis=1), axis=0)
pd.testing.assert_frame_equal(frequencies, result)

def test_compute_counts_with_verbose_label(self) -> None:
"""Test if the counts are computed correctly."""
iid_partitioner = IidPartitioner(num_partitions=2)
dataset = self.dataset
new_col_names = [
str(col_id) for col_id in range(len(self.dataset.unique("label")))
]
dataset = dataset.cast_column(
"label",
ClassLabel(
num_classes=len(self.dataset.unique("label")), names=new_col_names
),
)
iid_partitioner.dataset = dataset
result = self.result.copy()
result.columns = new_col_names
count = compute_counts(iid_partitioner, column_name="label", verbose_names=True)
pd.testing.assert_frame_equal(count, result)

def test_compute_frequencies_with_verbose_label(self) -> None:
"""Test if the frequencies are computed correctly."""
iid_partitioner = IidPartitioner(num_partitions=2)
dataset = self.dataset
new_col_names = [
str(col_id) for col_id in range(len(self.dataset.unique("label")))
]
dataset = dataset.cast_column(
"label",
ClassLabel(
num_classes=len(self.dataset.unique("label")), names=new_col_names
),
)
iid_partitioner.dataset = dataset
result = self.result.copy()
result.columns = new_col_names
result = result.div(result.sum(axis=1), axis=0)
frequencies = compute_frequencies(
iid_partitioner, column_name="label", verbose_names=True
)
print(frequencies)
pd.testing.assert_frame_equal(frequencies, result)

def test_compute_count_with_smaller_max_partitions(self) -> None:
"""Test is compute_count works when the max_partitions<total partitions."""
iid_partitioner = IidPartitioner(num_partitions=2)
iid_partitioner.dataset = self.dataset
count = compute_counts(
iid_partitioner, column_name="label", max_num_partitions=1
)
pd.testing.assert_frame_equal(count, self.result.iloc[:1])

def test_compute_count_with_bigger_max_partitions(self) -> None:
"""Test is compute_count works when the max_partitions>total partitions."""
iid_partitioner = IidPartitioner(num_partitions=2)
iid_partitioner.dataset = self.dataset
count = compute_counts(
iid_partitioner, column_name="label", max_num_partitions=3
)
pd.testing.assert_frame_equal(count, self.result)


class TestPrivateMetricsUtils(unittest.TestCase):
"""Test metrics utils."""

@parameterized.expand( # type: ignore
Expand All @@ -34,9 +148,9 @@ class TestMetricsUtils(unittest.TestCase):
([1, 1, 2], [1, 2, 3, 4], pd.Series([2, 1, 0, 0], index=[1, 2, 3, 4])),
]
)
def test_compute_counts(self, labels, unique_labels, expected) -> None:
def test__compute_counts(self, labels, unique_labels, expected) -> None:
"""Test if the counts are computed correctly."""
result = compute_counts(labels, unique_labels)
result = _compute_counts(labels, unique_labels)
pd.testing.assert_series_equal(result, expected)

@parameterized.expand( # type: ignore
Expand All @@ -56,7 +170,7 @@ def test_compute_counts(self, labels, unique_labels, expected) -> None:
)
def test_compute_distribution(self, labels, unique_labels, expected) -> None:
"""Test if the distributions are computed correctly."""
result = compute_frequency(labels, unique_labels)
result = _compute_frequencies(labels, unique_labels)
pd.testing.assert_series_equal(result, expected, atol=0.001)

@parameterized.expand( # type: ignore
Expand All @@ -67,22 +181,22 @@ def test_compute_distribution(self, labels, unique_labels, expected) -> None:
)
def test_distribution_sum_to_one(self, labels, unique_labels) -> None:
"""Test if distributions sum up to one."""
result = compute_frequency(labels, unique_labels)
result = _compute_frequencies(labels, unique_labels)
self.assertAlmostEqual(result.sum(), 1.0)

def test_compute_counts_non_unique_labels(self) -> None:
"""Test if not having the unique labels raises ValueError."""
labels = [1, 2, 3]
unique_labels = [1, 2, 2, 3]
with self.assertRaises(ValueError):
compute_counts(labels, unique_labels)
_compute_counts(labels, unique_labels)

def test_compute_distribution_non_unique_labels(self) -> None:
"""Test if not having the unique labels raises ValueError."""
labels = [1, 1, 2, 3]
unique_labels = [1, 1, 2, 3]
with self.assertRaises(ValueError):
compute_frequency(labels, unique_labels)
_compute_frequencies(labels, unique_labels)


if __name__ == "__main__":
Expand Down

0 comments on commit 9eaefdc

Please sign in to comment.