Skip to content

Commit df4c389

Browse files
committed
Support pandas.DataFrame as input
1 parent 2e0660a commit df4c389

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3148,6 +3148,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
31483148
df = pd.DataFrame(array, index=group_labels, columns=dataset_labels)
31493149
df.plot.bar()
31503150
3151+
- a `pandas.DataFrame`.
3152+
3153+
.. code-block::
3154+
3155+
df = pd.DataFrame(
3156+
np.random.random((2, 3))
3157+
index=["group_A", "group_B"],
3158+
columns=["dataset_0", "dataset_1", "dataset_2"]
3159+
)
3160+
grouped_bar(df)
3161+
3162+
Note that ``grouped_bar(df)`` produced a structurally equivalent plot like
3163+
``df.plot.bar()`.
3164+
31513165
positions : array-like, optional
31523166
The center positions of the bar groups. The values have to be equidistant.
31533167
If not given, a sequence of integer positions 0, 1, 2, ... is used.
@@ -3198,13 +3212,19 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
31983212
See also `.Artist.remove()`.
31993213
32003214
"""
3201-
if hasattr(heights, 'keys'):
3215+
if cbook._is_pandas_dataframe(heights):
3216+
if labels is None:
3217+
labels = heights.columns.tolist()
3218+
if tick_labels is None:
3219+
tick_labels = heights.index.tolist()
3220+
heights = heights.to_numpy().T
3221+
elif hasattr(heights, 'keys'): # dict
32023222
if labels is not None:
32033223
raise ValueError(
32043224
"'labels' cannot be used if 'heights' are a mapping")
32053225
labels = heights.keys()
32063226
heights = list(heights.values())
3207-
elif hasattr(heights, 'shape'):
3227+
elif hasattr(heights, 'shape'): # numpy array
32083228
heights = heights.T
32093229

32103230
num_datasets = len(heights)

lib/matplotlib/axes/_axes.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ class Axes(_AxesBase):
287287
) -> PolyCollection: ...
288288
def grouped_bar(
289289
self,
290-
heights : Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray,
290+
heights : Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame,
291291
*,
292292
positions : ArrayLike | None = ...,
293293
tick_labels : Sequence[str] | None = ...,

lib/matplotlib/pyplot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494

9595
import PIL.Image
9696
from numpy.typing import ArrayLike
97+
import pandas as pd
9798

9899
import matplotlib.axes
99100
import matplotlib.artist
@@ -3391,7 +3392,7 @@ def grid(
33913392
# Autogenerated by boilerplate.py. Do not edit as changes will be lost.
33923393
@_copy_docstring_and_deprecators(Axes.grouped_bar)
33933394
def grouped_bar(
3394-
heights: Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray,
3395+
heights: Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame,
33953396
*,
33963397
positions: ArrayLike | None = None,
33973398
group_spacing: float | None = 1.5,

0 commit comments

Comments
 (0)