-
Notifications
You must be signed in to change notification settings - Fork 942
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(datasets) Add label distribution visualization (#3451)
Co-authored-by: jafermarq <[email protected]>
- Loading branch information
1 parent
72244a8
commit 097b803
Showing
12 changed files
with
1,039 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Metrics package.""" | ||
|
||
|
||
from flwr_datasets.metrics.utils import compute_counts, compute_frequency | ||
|
||
__all__ = [ | ||
"compute_counts", | ||
"compute_frequency", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Utils for metrics computation.""" | ||
|
||
|
||
from typing import List, Union | ||
|
||
import pandas as pd | ||
|
||
|
||
def compute_counts(
    labels: Union[List[int], List[str]], unique_labels: Union[List[int], List[str]]
) -> pd.Series:
    """Compute the count of labels when taking into account all possible labels.

    Also known as absolute frequency.

    Parameters
    ----------
    labels: Union[List[int], List[str]]
        The labels from the datasets.
    unique_labels: Union[List[int], List[str]]
        The reference list of all unique labels. Every label listed here appears
        in the result, with a count of zero if it never occurs in ``labels``.

    Returns
    -------
    label_counts: pd.Series
        The pd.Series with labels as indices and counts as values. The index
        follows the order of ``unique_labels``; labels that occur in ``labels``
        but are missing from ``unique_labels`` are appended after them.

    Raises
    ------
    ValueError
        If ``unique_labels`` contains duplicated elements.
    """
    reference_labels = set(unique_labels)
    if len(unique_labels) != len(reference_labels):
        raise ValueError("unique_labels must contain unique elements only.")

    observed_counts = pd.Series(labels).value_counts()
    # Keep labels that were observed but not announced in `unique_labels`, so
    # that the counts always add up to the number of (non-missing) samples.
    unexpected_labels = [
        label for label in observed_counts.index if label not in reference_labels
    ]
    # Reindexing (instead of adding two Series) gives a deterministic order:
    # Series.add aligns on the union of both indexes, which pandas sorts only
    # in some cases, making the resulting order depend on the data.
    label_counts = observed_counts.reindex(
        list(unique_labels) + unexpected_labels, fill_value=0
    ).astype(int)
    # value_counts() may name its result (e.g. "count"); return an unnamed Series.
    label_counts.name = None
    return label_counts
|
||
|
||
def compute_frequency(
    labels: Union[List[int], List[str]], unique_labels: Union[List[int], List[str]]
) -> pd.Series:
    """Compute the relative frequency of every label.

    The absolute counts returned by ``compute_counts`` are divided by the number
    of samples in ``labels``.

    Parameters
    ----------
    labels: Union[List[int], List[str]]
        The labels from the datasets.
    unique_labels: Union[List[int], List[str]]
        The reference list of all unique labels. It is passed to
        ``compute_counts`` so that labels absent from ``labels`` are still
        reported (with a frequency of zero).

    Returns
    -------
    label_frequencies: pd.Series
        The pd.Series with labels as indices and relative frequencies as values.
        When ``labels`` is empty, the counts are returned cast to float.
    """
    absolute_counts = compute_counts(labels, unique_labels)
    num_samples = len(labels)
    if num_samples != 0:
        return absolute_counts.divide(num_samples)
    # No samples: skip the division by zero and only convert the dtype
    return absolute_counts.astype(float)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for metrics utils.""" | ||
# pylint: disable=no-self-use | ||
|
||
|
||
import unittest | ||
|
||
import pandas as pd | ||
from parameterized import parameterized | ||
|
||
from flwr_datasets.metrics.utils import compute_counts, compute_frequency | ||
|
||
|
||
class TestMetricsUtils(unittest.TestCase):
    """Unit tests for ``compute_counts`` and ``compute_frequency``."""

    @parameterized.expand(  # type: ignore
        [
            ([1, 2, 2, 3], [1, 2, 3, 4], pd.Series([1, 2, 1, 0], index=[1, 2, 3, 4])),
            ([], [1, 2, 3], pd.Series([0, 0, 0], index=[1, 2, 3])),
            ([1, 1, 2], [1, 2, 3, 4], pd.Series([2, 1, 0, 0], index=[1, 2, 3, 4])),
        ]
    )
    def test_compute_counts(self, labels, unique_labels, expected) -> None:
        """Check the absolute counts against hand-computed values."""
        actual = compute_counts(labels, unique_labels)
        pd.testing.assert_series_equal(actual, expected)

    @parameterized.expand(  # type: ignore
        [
            (
                [1, 1, 2, 2, 2, 3],
                [1, 2, 3, 4],
                pd.Series([0.3333, 0.5, 0.1667, 0.0], index=[1, 2, 3, 4]),
            ),
            ([], [1, 2, 3], pd.Series([0.0, 0.0, 0.0], index=[1, 2, 3])),
            (
                ["a", "b", "b", "c"],
                ["a", "b", "c", "d"],
                pd.Series([0.25, 0.50, 0.25, 0.0], index=["a", "b", "c", "d"]),
            ),
        ]
    )
    def test_compute_distribution(self, labels, unique_labels, expected) -> None:
        """Check the relative frequencies against rounded reference values."""
        actual = compute_frequency(labels, unique_labels)
        # Expected values are rounded to four digits, hence the tolerance
        pd.testing.assert_series_equal(actual, expected, atol=0.001)

    @parameterized.expand(  # type: ignore
        [
            (["a", "b", "b", "c"], ["a", "b", "c"]),
            ([1, 2, 2, 3, 3, 3, 4], [1, 2, 3, 4]),
        ]
    )
    def test_distribution_sum_to_one(self, labels, unique_labels) -> None:
        """Check that the relative frequencies add up to one."""
        total = compute_frequency(labels, unique_labels).sum()
        self.assertAlmostEqual(total, 1.0)

    def test_compute_counts_non_unique_labels(self) -> None:
        """Check that duplicated unique labels raise ValueError for counts."""
        labels = [1, 2, 3]
        unique_labels = [1, 2, 2, 3]
        self.assertRaises(ValueError, compute_counts, labels, unique_labels)

    def test_compute_distribution_non_unique_labels(self) -> None:
        """Check that duplicated unique labels raise ValueError for frequency."""
        labels = [1, 1, 2, 3]
        unique_labels = [1, 1, 2, 3]
        self.assertRaises(ValueError, compute_frequency, labels, unique_labels)
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Visualization package.""" | ||
|
||
|
||
from .comparison_label_distribution import plot_comparison_label_distribution | ||
from .label_distribution import plot_label_distributions | ||
|
||
__all__ = [ | ||
"plot_label_distributions", | ||
"plot_comparison_label_distribution", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Label distribution bar plotting.""" | ||
|
||
|
||
from typing import Any, Dict, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from matplotlib import colors as mcolors | ||
from matplotlib import pyplot as plt | ||
from matplotlib.axes import Axes | ||
|
||
|
||
# pylint: disable=too-many-arguments,too-many-locals,too-many-branches | ||
def _plot_bar(
    dataframe: pd.DataFrame,
    axis: Optional[Axes],
    figsize: Optional[Tuple[float, float]],
    title: str,
    colormap: Optional[Union[str, mcolors.Colormap]],
    partition_id_axis: str,
    size_unit: str,
    legend: bool,
    legend_title: Optional[str],
    plot_kwargs: Optional[Dict[str, Any]],
    legend_kwargs: Optional[Dict[str, Any]],
) -> Axes:
    """Draw the content of ``dataframe`` as a (by default stacked) bar plot.

    Parameters
    ----------
    dataframe: pd.DataFrame
        Data to plot. Its number of rows is used as the number of partitions and
        its column names are used to size the legend offset.
    axis: Optional[Axes]
        Axes to draw on. If None, a new figure and axes are created.
    figsize: Optional[Tuple[float, float]]
        Size of the newly created figure. Only used when ``axis`` is None; if it
        is None as well, the size comes from ``_initialize_figsize``.
    title: str
        Title of the plot. It always replaces a "title" entry of ``plot_kwargs``.
    colormap: Optional[Union[str, mcolors.Colormap]]
        Colormap of the bars. It takes precedence over a "colormap" entry of
        ``plot_kwargs``; if neither is given, "RdYlGn" is used.
    partition_id_axis: str
        "x" produces vertical bars (kind "bar"), any other value produces
        horizontal bars (kind "barh"), unless "kind" is set in ``plot_kwargs``.
    size_unit: str
        Passed to ``_initialize_xy_labels`` to create the default axis labels.
    legend: bool
        Whether to add a legend to the figure.
    legend_title: Optional[str]
        Title of the legend. It takes precedence over a "title" entry of
        ``legend_kwargs``; if neither is given, "Labels" is used.
    plot_kwargs: Optional[Dict[str, Any]]
        Extra keyword arguments for ``dataframe.plot``. The dictionary passed by
        the caller is not modified.
    legend_kwargs: Optional[Dict[str, Any]]
        Extra keyword arguments for the figure legend. The dictionary passed by
        the caller is not modified.

    Returns
    -------
    axis: Axes
        The axes returned by ``dataframe.plot``.
    """
    if axis is None:
        if figsize is None:
            figsize = _initialize_figsize(
                partition_id_axis=partition_id_axis, num_partitions=dataframe.shape[0]
            )
        _, axis = plt.subplots(figsize=figsize)

    # Work on a copy so the defaults set below never leak into the caller's dict
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)

    plot_kwargs.setdefault("kind", "bar" if partition_id_axis == "x" else "barh")

    # Handle non-optional parameters
    plot_kwargs["title"] = title

    # Handle optional parameters
    if colormap is not None:
        plot_kwargs["colormap"] = colormap
    else:
        plot_kwargs.setdefault("colormap", "RdYlGn")

    # Default axis labels are only applied when the user specified neither label
    if "xlabel" not in plot_kwargs and "ylabel" not in plot_kwargs:
        xlabel, ylabel = _initialize_xy_labels(
            size_unit=size_unit, partition_id_axis=partition_id_axis
        )
        plot_kwargs["xlabel"] = xlabel
        plot_kwargs["ylabel"] = ylabel

    # Make the x ticks readable (they appear 90 degrees rotated by default)
    plot_kwargs.setdefault("rot", 0)

    # Legend is handled separately (via the figure legend call below, not in plot())
    plot_kwargs.setdefault("legend", False)

    # Make the bar plot stacked
    plot_kwargs.setdefault("stacked", True)

    axis = dataframe.plot(
        ax=axis,
        **plot_kwargs,
    )

    if legend:
        # Same as for plot_kwargs: never mutate the caller's dictionary
        legend_kwargs = {} if legend_kwargs is None else dict(legend_kwargs)

        if legend_title is not None:
            legend_kwargs["title"] = legend_title
        else:
            legend_kwargs.setdefault("title", "Labels")

        legend_kwargs.setdefault("loc", "outside center right")

        if "bbox_to_anchor" not in legend_kwargs:
            # Shift the legend further right for longer label names (capped at 0.15)
            max_len_label_str = max(
                (len(str(column)) for column in dataframe.columns), default=0
            )
            shift = min(0.05 + max_len_label_str / 100, 0.15)
            legend_kwargs["bbox_to_anchor"] = (1.0 + shift, 0.5)

        handles, legend_labels = axis.get_legend_handles_labels()
        # Entries are listed in reversed column order, presumably so that the
        # legend reads top-down like the stacked bars (to be confirmed).
        _ = axis.figure.legend(
            handles=handles[::-1], labels=legend_labels[::-1], **legend_kwargs
        )

    # Heuristic to make the partition id on xticks non-overlapping
    if partition_id_axis == "x":
        xticklabels = axis.get_xticklabels()
        if len(xticklabels) > 20:
            # Make every other xtick label not visible
            for i, label in enumerate(xticklabels):
                if i % 2 == 1:
                    label.set_visible(False)
    return axis
|
||
|
||
def _initialize_figsize( | ||
partition_id_axis: str, | ||
num_partitions: int, | ||
) -> Tuple[float, float]: | ||
figsize = (0.0, 0.0) | ||
if partition_id_axis == "x": | ||
figsize = (6.4, 4.8) | ||
elif partition_id_axis == "y": | ||
figsize = (6.4, np.sqrt(num_partitions)) | ||
return figsize | ||
|
||
|
||
def _initialize_xy_labels(size_unit: str, partition_id_axis: str) -> Tuple[str, str]: | ||
xlabel = "Partition ID" | ||
ylabel = "Count" if size_unit == "absolute" else "Percent %" | ||
|
||
if partition_id_axis == "y": | ||
xlabel, ylabel = ylabel, xlabel | ||
|
||
return xlabel, ylabel |
Oops, something went wrong.