From df4c38924e21f8e8a70aabbaec69b464112ec6a4 Mon Sep 17 00:00:00 2001 From: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Date: Sun, 19 Jan 2025 00:21:55 +0100 Subject: [PATCH] Support pandas.DataFrame as input --- lib/matplotlib/axes/_axes.py | 24 ++++++++++++++++++++++-- lib/matplotlib/axes/_axes.pyi | 2 +- lib/matplotlib/pyplot.py | 3 ++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 50ea66656e43..e8643f995503 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -3148,6 +3148,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing df = pd.DataFrame(array, index=group_labels, columns=dataset_labels) df.plot.bar() + - a `pandas.DataFrame`. + + .. code-block:: + + df = pd.DataFrame( + np.random.random((2, 3)) + index=["group_A", "group_B"], + columns=["dataset_0", "dataset_1", "dataset_2"] + ) + grouped_bar(df) + + Note that ``grouped_bar(df)`` produced a structurally equivalent plot like + ``df.plot.bar()`. + positions : array-like, optional The center positions of the bar groups. The values have to be equidistant. If not given, a sequence of integer positions 0, 1, 2, ... is used. @@ -3198,13 +3212,19 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing See also `.Artist.remove()`. """ - if hasattr(heights, 'keys'): + if cbook._is_pandas_dataframe(heights): + if labels is None: + labels = heights.columns.tolist() + if tick_labels is None: + tick_labels = heights.index.tolist() + heights = heights.to_numpy().T + elif hasattr(heights, 'keys'): # dict if labels is not None: raise ValueError( "'labels' cannot be used if 'heights' are a mapping") labels = heights.keys() heights = list(heights.values()) - elif hasattr(heights, 'shape'): + elif hasattr(heights, 'shape'): # numpy array heights = heights.T num_datasets = len(heights) diff --git a/lib/matplotlib/axes/_axes.pyi b/lib/matplotlib/axes/_axes.pyi index ee9fcc7335f9..75ae4a821ec6 100644 --- a/lib/matplotlib/axes/_axes.pyi +++ b/lib/matplotlib/axes/_axes.pyi @@ -287,7 +287,7 @@ class Axes(_AxesBase): ) -> PolyCollection: ... def grouped_bar( self, - heights : Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray, + heights : Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame, *, positions : ArrayLike | None = ..., tick_labels : Sequence[str] | None = ..., diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 06a52fc1cb08..3c1bfdf953b0 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -94,6 +94,7 @@ import PIL.Image from numpy.typing import ArrayLike + import pandas as pd import matplotlib.axes import matplotlib.artist @@ -3391,7 +3392,7 @@ def grid( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.grouped_bar) def grouped_bar( - heights: Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray, + heights: Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame, *, positions: ArrayLike | None = None, group_spacing: float | None = 1.5,