Apply suggestions from code review
Co-authored-by: hannah <[email protected]>
timhoffm and story645 committed Jan 24, 2025
1 parent bf2894a commit 4cc0e31
Showing 2 changed files with 36 additions and 40 deletions.
9 changes: 5 additions & 4 deletions doc/users/next_whats_new/grouped_bar.rst
@@ -10,16 +10,17 @@ Example:

.. plot::
:include-source: true
:alt: Diagram of a grouped bar chart of 3 datasets with 2 categories.

import matplotlib.pyplot as plt

categories = ['A', 'B']
datasets = {
'dataset 0': [1.0, 3.0],
'dataset 1': [1.4, 3.4],
'dataset 2': [1.8, 3.8],
'dataset 0': [1, 11],
'dataset 1': [3, 13],
'dataset 2': [5, 15],
}

fig, ax = plt.subplots(figsize=(4, 2.2))
fig, ax = plt.subplots()
ax.grouped_bar(datasets, tick_labels=categories)
ax.legend()
67 changes: 31 additions & 36 deletions lib/matplotlib/axes/_axes.py
@@ -3077,18 +3077,14 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
This function is new in v3.11, and the API is still provisional.
We may still fine-tune some aspects based on user-feedback.
This is a convenience function to plot bars for multiple datasets.
In particular, it simplifies positioning of the bars compared to individual
`~.Axes.bar` plots.
Bar plots present categorical data as a sequence of bars, one bar per category.
We call one set of such values a *dataset* and it's bars all share the same
We call one set of such values a *dataset* and its bars all share the same
color. Grouped bar plots show multiple such datasets, where the values per
category are grouped together. The category names are drawn as tick labels
below the bar groups. Each dataset has a distinct bar color, and can optionally
get a label that is used for the legend.
Here is an example call structure and the corresponding plot:
Example:
.. code-block:: python
@@ -3098,6 +3094,10 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
.. plot:: _embedded_plots/grouped_bar.py
``grouped_bar()`` is a high-level plotting function for grouped bar charts.
Use `~.Axes.bar` instead if you need finer-grained control over individual bar
positions or widths.
Parameters
----------
heights : list of array-like or dict of array-like or 2D array \
@@ -3121,25 +3121,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
- dict of array-like: A mapping from names to datasets. Each dataset
(dict value) must have the same number of elements.
This is similar to passing a list of array-like, with the addition that
each dataset gets a name.
Example call:
.. code-block:: python
grouped_bar({'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2})
The names are used as *labels*, i.e. the following two calls are
equivalent:
The names are used as *labels*, which is equivalent to passing in a dict
of array-like:
.. code-block:: python
data_dict = {'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2}
grouped_bar(data_dict)
grouped_bar(data_dict.values(), labels=data_dict.keys())
When using a dict-like input, you must not pass *labels* explicitly.
When using a dict input, you must not pass *labels* explicitly.
- a 2D array: The rows are the categories, the columns are the different
datasets.
@@ -3154,21 +3149,16 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
.. code-block:: python
group_labels = ["group_A", "group_B"]
categories = ["A", "B"]
dataset_labels = ["dataset_0", "dataset_1", "dataset_2"]
array = np.random.random((2, 3))
Note that this is consistent with pandas. These two calls produce
the same bar plot structure:
.. code-block:: python
grouped_bar(array, tick_labels=categories, labels=dataset_labels)
df = pd.DataFrame(array, index=categories, columns=dataset_labels)
df.plot.bar()
- a `pandas.DataFrame`.
The index is used for the categories, the columns are used for the
datasets.
.. code-block:: python
df = pd.DataFrame(
@@ -3178,6 +3168,9 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
)
grouped_bar(df)
# is equivalent to
grouped_bar(df.to_numpy(), tick_labels=df.index, labels=df.columns)
Note that ``grouped_bar(df)`` produces a plot that is structurally equivalent
to ``df.plot.bar()``.
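
The two layouts described above (a list of datasets versus a 2D array whose rows are categories and whose columns are datasets) are transposes of each other. The following is a minimal pure-Python sketch of that relationship, not matplotlib's code; the helper name `columns_as_datasets` is hypothetical and is used only for illustration.

```python
def columns_as_datasets(rows):
    """Turn a row-per-category table into a list with one entry per dataset.

    Hypothetical helper: each input row holds the values of one category,
    each output list holds the values of one dataset (one table column).
    """
    return [list(column) for column in zip(*rows)]


# 2 categories (rows) x 3 datasets (columns)
array = [
    [1, 3, 5],     # category 'A'
    [11, 13, 15],  # category 'B'
]

datasets = columns_as_datasets(array)
print(datasets)  # [[1, 11], [3, 13], [5, 15]]
```

The result matches the per-dataset lists used in the what's-new example, which is why a 2D array or DataFrame can be drawn with the same grouping logic once it is transposed.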
@@ -3187,9 +3180,8 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
tick_labels : list of str, optional
The category labels, which are placed on ticks at the center *positions*
of the bar groups.
If not set, the axis ticks (positions and labels) are left unchanged.
of the bar groups. If not set, the axis ticks (positions and labels) are
left unchanged.
labels : list of str, optional
The labels of the datasets, i.e. the bars within one group.
@@ -3249,7 +3241,7 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
aspects. ``bar(x, y)`` is a lower-level API and places bars with height *y*
at explicit positions *x*. It also allows specifying individual bar widths
and colors. This kind of detailed control and flexibility is difficult to
manage and often not needed when plotting multiple datasets as grouped bar
manage and often not needed when plotting multiple datasets as a grouped bar
plot. Therefore, ``grouped_bar`` focusses on the abstraction of bar plots
as visualization of categorical data.
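
To illustrate the bookkeeping that plain ``bar(x, y)`` calls would need for a grouped layout, here is a rough sketch of one way to compute bar centers within a group. It is not matplotlib's implementation: the helper name `grouped_bar_offsets` is hypothetical, and treating *group_spacing* and *bar_spacing* as multiples of the bar width is an assumption made for this example.

```python
def grouped_bar_offsets(num_datasets, group_spacing=1.5, bar_spacing=0.0,
                        group_distance=1.0):
    """Return (bar_width, offsets) for one bar group.

    Assumption: group_spacing and bar_spacing are multiples of the bar width,
    and group_distance is the distance between neighbouring group centers.
    The offsets are bar centers relative to the group center.
    """
    # One group slot holds all bars, the gaps between them, and the group gap.
    bar_width = group_distance / (
        num_datasets + (num_datasets - 1) * bar_spacing + group_spacing
    )
    step = bar_width * (1 + bar_spacing)  # center-to-center distance
    occupied = num_datasets * bar_width + (num_datasets - 1) * bar_spacing * bar_width
    first_center = -occupied / 2 + bar_width / 2
    return bar_width, [first_center + i * step for i in range(num_datasets)]


bar_width, offsets = grouped_bar_offsets(3)
group_centers = [0, 1]  # one center per category
for offset in offsets:
    # Each dataset would be drawn with one bar() call at these x positions.
    print([center + offset for center in group_centers])
```

With three datasets the offsets are symmetric around the group center, so the middle dataset sits exactly on the tick position.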
@@ -3309,8 +3301,18 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
heights = heights.T

num_datasets = len(heights)
dataset_0 = next(iter(heights))
num_groups = len(dataset_0)
num_groups = len(next(iter(heights))) # inferred from first dataset

# validate that all datasets have the same length, i.e. num_groups
# - can be skipped if heights is an array
if not hasattr(heights, 'shape'):
for i, dataset in enumerate(heights):
if len(dataset) != num_groups:
raise ValueError(
"'heights' contains datasets with different number of "
f"elements. dataset 0 has {num_groups} elements but "
f"dataset {i} has {len(dataset)} elements."
)

if positions is None:
group_centers = np.arange(num_groups)
@@ -3325,13 +3327,6 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
else:
group_distance = 1

for i, dataset in enumerate(heights):
if len(dataset) != num_groups:
raise ValueError(
f"'x' indicates {num_groups} groups, but dataset {i} "
f"has {len(dataset)} groups"
)

_api.check_in_list(["vertical", "horizontal"], orientation=orientation)

if colors is None:
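
The length validation that this commit moves earlier in the function can be tried in isolation. The following standalone sketch mirrors the check in the diff, assuming list or dict input only; the function name `check_equal_lengths` is hypothetical and the array shortcut (`hasattr(heights, 'shape')`) is left out.

```python
def check_equal_lengths(heights):
    """Return the number of groups, or raise if the datasets differ in length.

    Hypothetical standalone version of the check in the diff. Accepts a list
    of sequences or a dict mapping names to sequences.
    """
    datasets = list(heights.values()) if isinstance(heights, dict) else list(heights)
    if not datasets:
        return 0
    num_groups = len(datasets[0])  # inferred from the first dataset
    for i, dataset in enumerate(datasets):
        if len(dataset) != num_groups:
            raise ValueError(
                "'heights' contains datasets with different number of "
                f"elements. dataset 0 has {num_groups} elements but "
                f"dataset {i} has {len(dataset)} elements."
            )
    return num_groups


print(check_equal_lengths({'dataset 0': [1, 11], 'dataset 1': [3, 13]}))  # 2
```

Running the check before any positions are computed means a malformed input fails with a message about *heights* rather than about a derived quantity.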