Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: hannah <[email protected]>
  • Loading branch information
timhoffm and story645 committed Jan 24, 2025
1 parent bf2894a commit 162ce79
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 46 deletions.
9 changes: 5 additions & 4 deletions doc/users/next_whats_new/grouped_bar.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@ Example:

.. plot::
:include-source: true
:alt: Diagram of a grouped bar chart of 3 datasets with 2 categories.

import matplotlib.pyplot as plt

categories = ['A', 'B']
datasets = {
'dataset 0': [1.0, 3.0],
'dataset 1': [1.4, 3.4],
'dataset 2': [1.8, 3.8],
'dataset 0': [1, 11],
'dataset 1': [3, 13],
'dataset 2': [5, 15],
}

fig, ax = plt.subplots(figsize=(4, 2.2))
fig, ax = plt.subplots()
ax.grouped_bar(datasets, tick_labels=categories)
ax.legend()
79 changes: 37 additions & 42 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3073,22 +3073,19 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
"""
Make a grouped bar plot.
.. note::
.. versionadded:: 3.11
This function is new in v3.11, and the API is still provisional.
We may still fine-tune some aspects based on user-feedback.
This is a convenience function to plot bars for multiple datasets.
In particular, it simplifies positioning of the bars compared to individual
`~.Axes.bar` plots.
Bar plots present categorical data as a sequence of bars, one bar per category.
We call one set of such values a *dataset* and it's bars all share the same
We call one set of such values a *dataset* and its bars all share the same
color. Grouped bar plots show multiple such datasets, where the values per
category are grouped together. The category names are drawn as tick labels
below the bar groups. Each dataset has a distinct bar color, and can optionally
get a label that is used for the legend.
Here is an example call structure and the corresponding plot:
Example:
.. code-block:: python
Expand Down Expand Up @@ -3121,25 +3118,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
- dict of array-like: A mapping from names to datasets. Each dataset
(dict value) must have the same number of elements.
This is similar to passing a list of array-like, with the addition that
each dataset gets a name.
Example call:
.. code-block:: python
grouped_bar({'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2]})
data_dict = {'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2}
grouped_bar(data_dict)
The names are used as *labels*, i.e. the following two calls are
equivalent:
The names are used as *labels*, i.e. this is equivalent to
.. code-block:: python
data_dict = {'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2]}
grouped_bar(data_dict)
grouped_bar(data_dict.values(), labels=data_dict.keys())
When using a dict-like input, you must not pass *labels* explicitly.
When using a dict input, you must not pass *labels* explicitly.
- a 2D array: The rows are the categories, the columns are the different
datasets.
Expand All @@ -3154,30 +3146,31 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
.. code-block:: python
group_labels = ["group_A", "group_B"]
categories = ["A", "B"]
dataset_labels = ["dataset_0", "dataset_1", "dataset_2"]
array = np.random.random((2, 3))
Note that this is consistent with pandas. These two calls produce
the same bar plot structure:
.. code-block:: python
grouped_bar(array, tick_labels=categories, labels=dataset_labels)
df = pd.DataFrame(array, index=categories, columns=dataset_labels)
df.plot.bar()
- a `pandas.DataFrame`.
The index is used for the categories, the columns are used for the
datasets.
.. code-block:: python
df = pd.DataFrame(
np.random.random((2, 3))
index=["group_A", "group_B"],
np.random.random((2, 3)),
index=["A", "B"],
columns=["dataset_0", "dataset_1", "dataset_2"]
)
grouped_bar(df)
i.e. this is equivalent to
.. code-block::
grouped_bar(df.to_numpy(), tick_labels=df.index, labels=df.columns)
Note that ``grouped_bar(df)`` produces a structurally equivalent plot like
``df.plot.bar()``.
Expand All @@ -3187,22 +3180,21 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
tick_labels : list of str, optional
The category labels, which are placed on ticks at the center *positions*
of the bar groups.
If not set, the axis ticks (positions and labels) are left unchanged.
of the bar groups. If not set, the axis ticks (positions and labels) are
left unchanged.
labels : list of str, optional
The labels of the datasets, i.e. the bars within one group.
These will show up in the legend.
group_spacing : float, default: 1.5
The space between two bar groups in units of bar width.
The space between two bar groups as multiples of bar width.
The default value of 1.5 thus means that there's a gap of
1.5 bar widths between bar groups.
bar_spacing : float, default: 0
The space between bars in units of bar width.
The space between bars as multiples of bar width.
orientation : {"vertical", "horizontal"}, default: "vertical"
The direction of the bars.
Expand Down Expand Up @@ -3249,7 +3241,7 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
aspects. ``bar(x, y)`` is a lower-level API and places bars with height *y*
at explicit positions *x*. It also allows to specify individual bar widths
and colors. This kind of detailed control and flexibility is difficult to
manage and often not needed when plotting multiple datasets as grouped bar
manage and often not needed when plotting multiple datasets as a grouped bar
plot. Therefore, ``grouped_bar`` focusses on the abstraction of bar plots
as visualization of categorical data.
Expand Down Expand Up @@ -3309,8 +3301,18 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
heights = heights.T

num_datasets = len(heights)
dataset_0 = next(iter(heights))
num_groups = len(dataset_0)
num_groups = len(next(iter(heights))) # inferred from first dataset

# validate that all datasets have the same length, i.e. num_groups
# - can be skipped if heights is an array
if not hasattr(heights, 'shape'):
for i, dataset in enumerate(heights):
if len(dataset) != num_groups:
raise ValueError(
"'heights' contains datasets with different number of "
f"elements. dataset 0 has {num_groups} elements but "
f"dataset {i} has {len(dataset)} elements."
)

if positions is None:
group_centers = np.arange(num_groups)
Expand All @@ -3325,13 +3327,6 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
else:
group_distance = 1

for i, dataset in enumerate(heights):
if len(dataset) != num_groups:
raise ValueError(
f"'x' indicates {num_groups} groups, but dataset {i} "
f"has {len(dataset)} groups"
)

_api.check_in_list(["vertical", "horizontal"], orientation=orientation)

if colors is None:
Expand Down

0 comments on commit 162ce79

Please sign in to comment.