Skip to content

Commit

Permalink
Support pandas.DataFrame as input
Browse files Browse the repository at this point in the history
  • Loading branch information
timhoffm committed Jan 18, 2025
1 parent 2e0660a commit df4c389
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
24 changes: 22 additions & 2 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3148,6 +3148,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
df = pd.DataFrame(array, index=group_labels, columns=dataset_labels)
df.plot.bar()
- a `pandas.DataFrame`.
.. code-block::
df = pd.DataFrame(
np.random.random((2, 3))
index=["group_A", "group_B"],
columns=["dataset_0", "dataset_1", "dataset_2"]
)
grouped_bar(df)
Note that ``grouped_bar(df)`` produced a structurally equivalent plot like
``df.plot.bar()`.
positions : array-like, optional
The center positions of the bar groups. The values have to be equidistant.
If not given, a sequence of integer positions 0, 1, 2, ... is used.
Expand Down Expand Up @@ -3198,13 +3212,19 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
See also `.Artist.remove()`.
"""
if hasattr(heights, 'keys'):
if cbook._is_pandas_dataframe(heights):
if labels is None:
labels = heights.columns.tolist()
if tick_labels is None:
tick_labels = heights.index.tolist()
heights = heights.to_numpy().T
elif hasattr(heights, 'keys'): # dict
if labels is not None:
raise ValueError(
"'labels' cannot be used if 'heights' are a mapping")
labels = heights.keys()
heights = list(heights.values())
elif hasattr(heights, 'shape'):
elif hasattr(heights, 'shape'): # numpy array
heights = heights.T

num_datasets = len(heights)
Expand Down
2 changes: 1 addition & 1 deletion lib/matplotlib/axes/_axes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ class Axes(_AxesBase):
) -> PolyCollection: ...
def grouped_bar(
self,
heights : Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray,
heights : Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame,
*,
positions : ArrayLike | None = ...,
tick_labels : Sequence[str] | None = ...,
Expand Down
3 changes: 2 additions & 1 deletion lib/matplotlib/pyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@

import PIL.Image
from numpy.typing import ArrayLike
import pandas as pd

import matplotlib.axes
import matplotlib.artist
Expand Down Expand Up @@ -3391,7 +3392,7 @@ def grid(
# Autogenerated by boilerplate.py. Do not edit as changes will be lost.
@_copy_docstring_and_deprecators(Axes.grouped_bar)
def grouped_bar(
heights: Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray,
heights: Sequence[ArrayLike] | dict[str, ArrayLike] | np.ndarray | pd.DataFrame,
*,
positions: ArrayLike | None = None,
group_spacing: float | None = 1.5,
Expand Down

0 comments on commit df4c389

Please sign in to comment.