Skip to content

Commit

Permalink
Add tests for grouped_bar()
Browse files Browse the repository at this point in the history
  • Loading branch information
timhoffm committed Jan 23, 2025
1 parent 121c21d commit 8963636
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
85 changes: 85 additions & 0 deletions lib/matplotlib/tests/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,91 @@ def test_bar_datetime_start():
assert isinstance(ax.xaxis.get_major_formatter(), mdates.AutoDateFormatter)


@image_comparison(["grouped_bar.png"], style="mpl20")
def test_grouped_bar():
data = {
'data1': [1, 2, 3],
'data2': [1.2, 2.2, 3.2],
'data3': [1.4, 2.4, 3.4],
}

fig, ax = plt.subplots()
ax.grouped_bar(data, tick_labels=['A', 'B', 'C'],
group_spacing=0.5, bar_spacing=0.1,
colors=['#1f77b4', '#58a1cf', '#abd0e6'])
ax.set_yticks([])


@check_figures_equal(extensions=["png"])
def test_grouped_bar_list_of_datasets(fig_test, fig_ref):
categories = ['A', 'B']
data1 = [1, 1.2]
data2 = [2, 2.4]
data3 = [3, 3.6]

ax = fig_test.subplots()
ax.grouped_bar([data1, data2, data3], tick_labels=categories,
labels=["data1", "data2", "data3"])
ax.legend()

ax = fig_ref.subplots()
label_pos = np.array([0, 1])
bar_width = 1 / (3 + 1.5) # 3 bars + 1.5 group_spacing
data_shift = -1 * bar_width + np.array([0, bar_width, 2 * bar_width])
ax.bar(label_pos + data_shift[0], data1, width=bar_width, label="data1")
ax.bar(label_pos + data_shift[1], data2, width=bar_width, label="data2")
ax.bar(label_pos + data_shift[2], data3, width=bar_width, label="data3")
ax.set_xticks(label_pos, categories)
ax.legend()


@check_figures_equal(extensions=["png"])
def test_grouped_bar_dict_of_datasets(fig_test, fig_ref):
categories = ['A', 'B']
data_dict = dict(data1=[1, 1.2], data2=[2, 2.4], data3=[3, 3.6])

ax = fig_test.subplots()
ax.grouped_bar(data_dict, tick_labels=categories)
ax.legend()

ax = fig_ref.subplots()
ax.grouped_bar(data_dict.values(), tick_labels=categories, labels=data_dict.keys())
ax.legend()


@check_figures_equal(extensions=["png"])
def test_grouped_bar_array(fig_test, fig_ref):
categories = ['A', 'B']
array = np.array([[1, 2, 3], [1.2, 2.4, 3.6]])
labels = ['data1', 'data2', 'data3']

ax = fig_test.subplots()
ax.grouped_bar(array, tick_labels=categories, labels=labels)
ax.legend()

ax = fig_ref.subplots()
list_of_datasets = [column for column in array.T]
ax.grouped_bar(list_of_datasets, tick_labels=categories, labels=labels)
ax.legend()


@check_figures_equal(extensions=["png"])
def test_grouped_bar_dataframe(fig_test, fig_ref, pd):
categories = ['A', 'B']
labels = ['data1', 'data2', 'data3']
df = pd.DataFrame([[1, 2, 3], [1.2, 2.4, 3.6]],
index=categories, columns=labels)

ax = fig_test.subplots()
ax.grouped_bar(df)
ax.legend()

ax = fig_ref.subplots()
list_of_datasets = [df[col].to_numpy() for col in df.columns]
ax.grouped_bar(list_of_datasets, tick_labels=categories, labels=labels)
ax.legend()


def test_boxplot_dates_pandas(pd):
# smoke test for boxplot and dates in pandas
data = np.random.rand(5, 2)
Expand Down

0 comments on commit 8963636

Please sign in to comment.