Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Directly check if two figures returned by a function are equal #95

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 72 additions & 44 deletions pytest_mpl/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,34 @@ def get_marker(item, marker_name):
return item.keywords.get(marker_name)


def _raise_on_image_difference(expected, actual, tol):
"""
Based on matplotlib.testing.decorators._raise_on_image_difference

Compare image size ourselves since the Matplotlib
exception is a bit cryptic in this case and doesn't show
the filenames
"""
from matplotlib.image import imread
from matplotlib.testing.compare import compare_images

expected_shape = imread(expected).shape[:2]
actual_shape = imread(actual).shape[:2]
if expected_shape != actual_shape:
error = SHAPE_MISMATCH_ERROR.format(expected_path=expected,
expected_shape=expected_shape,
actual_path=actual,
actual_shape=actual_shape)
pytest.fail(error, pytrace=False)

msg = compare_images(expected, actual, tol=tol)

if msg is None:
shutil.rmtree(os.path.dirname(expected))
else:
pytest.fail(msg, pytrace=False)


class ImageComparison(object):

def __init__(self, config, baseline_dir=None, generate_dir=None, results_dir=None):
Expand All @@ -195,9 +223,7 @@ def pytest_runtest_setup(self, item):
return

import matplotlib
from matplotlib.image import imread
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images
try:
from matplotlib.testing.decorators import remove_ticks_and_titles
except ImportError:
Expand Down Expand Up @@ -246,7 +272,10 @@ def item_function_wrapper(*args, **kwargs):
fig = original(*args, **kwargs)

if remove_text:
remove_ticks_and_titles(fig)
if not isinstance(fig, tuple):
remove_ticks_and_titles(fig)
else:
[remove_ticks_and_titles(f) for f in fig]

# Find test name to use as plot name
filename = compare.kwargs.get('filename', None)
Expand All @@ -260,52 +289,51 @@ def item_function_wrapper(*args, **kwargs):
# reference images or simply running the test.
if self.generate_dir is None:

# Save the figure
# Save the figure(s)
result_dir = tempfile.mkdtemp(dir=self.results_dir)
test_image = os.path.abspath(os.path.join(result_dir, filename))

fig.savefig(test_image, **savefig_kwargs)
close_mpl_figure(fig)

# Find path to baseline image
if baseline_remote:
baseline_image_ref = _download_file(baseline_dir, filename)
else:
baseline_image_ref = os.path.abspath(os.path.join(
os.path.dirname(item.fspath.strpath), baseline_dir, filename))

if not os.path.exists(baseline_image_ref):
pytest.fail("Image file not found for comparison test in: "
"\n\t{baseline_dir}"
"\n(This is expected for new tests.)\nGenerated Image: "
"\n\t{test}".format(baseline_dir=baseline_dir,
test=test_image),
pytrace=False)

# distutils may put the baseline images in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
baseline_image = os.path.abspath(os.path.join(result_dir,
'baseline-' + filename))
shutil.copyfile(baseline_image_ref, baseline_image)

# Compare image size ourselves since the Matplotlib
# exception is a bit cryptic in this case and doesn't show
# the filenames
expected_shape = imread(baseline_image).shape[:2]
actual_shape = imread(test_image).shape[:2]
if expected_shape != actual_shape:
error = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image,
expected_shape=expected_shape,
actual_path=test_image,
actual_shape=actual_shape)
pytest.fail(error, pytrace=False)

msg = compare_images(baseline_image, test_image, tol=tolerance)

if msg is None:
shutil.rmtree(result_dir)

if not isinstance(fig, tuple):
fig.savefig(test_image, **savefig_kwargs)
close_mpl_figure(fig)

# Find path to baseline image
if baseline_remote:
baseline_image_ref = _download_file(baseline_dir, filename)
else:
baseline_image_ref = os.path.abspath(os.path.join(
os.path.dirname(item.fspath.strpath), baseline_dir, filename))

if not os.path.exists(baseline_image_ref):
pytest.fail("Image file not found for comparison test in: "
"\n\t{baseline_dir}"
"\n(This is expected for new tests.)\nGenerated Image: "
"\n\t{test}".format(baseline_dir=baseline_dir,
test=test_image),
pytrace=False)

# distutils may put the baseline images in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
shutil.copyfile(baseline_image_ref, baseline_image)

else:
pytest.fail(msg, pytrace=False)
fig[0].savefig(test_image, **savefig_kwargs)
close_mpl_figure(fig[0])
fig[1].savefig(baseline_image, **savefig_kwargs)
close_mpl_figure(fig[1])

_raise_on_image_difference(
expected=baseline_image,
actual=test_image,
tol=tolerance
)

elif self.generate_dir and isinstance(fig, tuple):
close_mpl_figure(fig[0])
close_mpl_figure(fig[1])
pytest.skip("Skipping image comparison test")

else:

Expand Down
55 changes: 55 additions & 0 deletions tests/test_pytest_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,58 @@ def test_succeeds(self):
ax = fig.add_subplot(1, 1, 1)
ax.plot(self.x)
return fig


@pytest.mark.mpl_image_compare
def test_check_equal():
fig_test, ax_test = plt.subplots()
ax_test.plot([1, 3, 5])

fig_ref, ax_ref = plt.subplots()
ax_ref.plot([0, 1, 2], [1, 3, 5])

return fig_test, fig_ref


TEST_GENERATE_2 = """
import pytest
import matplotlib.pyplot as plt
@pytest.mark.mpl_image_compare
def test_gen_two_figs():
fig_test, ax_test = plt.subplots()
ax_test.plot([1, 3, 5])
fig_ref, ax_ref = plt.subplots()
ax_ref.plot([0, 1, 2], [1, 3, 7])
return fig_test, fig_ref
"""


def test_check_unequal_fails(tmpdir):

test_file = tmpdir.join("test2.py").strpath
with open(test_file, "w") as f:
f.write(TEST_GENERATE_2)

# If we use --mpl, it should detect that the two figures are not the same
code = subprocess.call([sys.executable, "-m", "pytest", "--mpl", test_file])
assert code != 0

# If we don't use --mpl option, the test should succeed
code = subprocess.call([sys.executable, "-m", "pytest", test_file])
assert code == 0


def test_skip_generate_two_figures(tmpdir):

test_file = tmpdir.join("test2.py").strpath
with open(test_file, "w") as f:
f.write(TEST_GENERATE_2)

gen_dir = tmpdir.mkdir("spam").mkdir("egg").strpath

# If we try to generate, the test should be skipped and a new file won't appear
code = subprocess.call([sys.executable, "-m", "pytest",
"--mpl-generate-path={0}".format(gen_dir),
test_file])
assert code == 0
assert not os.path.exists(os.path.join(gen_dir, "test_gen_two_figs.png"))