Skip to content

Commit

Permalink
Merge pull request #171 from ConorMacBride/fix-unittest
Browse files Browse the repository at this point in the history
Fix tests which exit before returning a figure or use `unittest.TestCase`
  • Loading branch information
ConorMacBride authored Jul 22, 2022
2 parents 48e652f + d84892b commit 310ee99
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 36 deletions.
75 changes: 45 additions & 30 deletions pytest_mpl/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,33 @@ def pathify(path):
return Path(path + ext)


def _pytest_pyfunc_call(obj, pyfuncitem):
testfunction = pyfuncitem.obj
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
obj.result = testfunction(**testargs)
return True
def generate_test_name(item):
"""
Generate a unique name for the hash for this test.
"""
if item.cls is not None:
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
else:
name = f"{item.module.__name__}.{item.name}"
return name


def wrap_figure_interceptor(plugin, item):
"""
Intercept and store figures returned by test functions.
"""
# Only intercept figures on marked figure tests
if get_compare(item) is not None:

# Use the full test name as a key to ensure correct figure is being retrieved
test_name = generate_test_name(item)

def figure_interceptor(store, obj):
def wrapper(*args, **kwargs):
store.return_value[test_name] = obj(*args, **kwargs)
return wrapper

item.obj = figure_interceptor(plugin, item.obj)


def pytest_report_header(config, startdir):
Expand Down Expand Up @@ -275,6 +296,7 @@ def __init__(self,
self._generated_hash_library = {}
self._test_results = {}
self._test_stats = None
self.return_value = {}

# https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
# turn debug prints on only if "-vv" or more passed
Expand All @@ -287,7 +309,7 @@ def generate_filename(self, item):
Given a pytest item, generate the figure filename.
"""
if self.config.getini('mpl-use-full-test-name'):
filename = self.generate_test_name(item) + '.png'
filename = generate_test_name(item) + '.png'
else:
compare = get_compare(item)
# Find test name to use as plot name
Expand All @@ -298,21 +320,11 @@ def generate_filename(self, item):
filename = str(pathify(filename))
return filename

def generate_test_name(self, item):
"""
Generate a unique name for the hash for this test.
"""
if item.cls is not None:
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
else:
name = f"{item.module.__name__}.{item.name}"
return name

def make_test_results_dir(self, item):
"""
Generate the directory to put the results in.
"""
test_name = pathify(self.generate_test_name(item))
test_name = pathify(generate_test_name(item))
results_dir = self.results_dir / test_name
results_dir.mkdir(exist_ok=True, parents=True)
return results_dir
Expand Down Expand Up @@ -526,7 +538,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
pytest.fail(f"Can't find hash library at path {hash_library_filename}")

hash_library = self.load_hash_library(hash_library_filename)
hash_name = self.generate_test_name(item)
hash_name = generate_test_name(item)
baseline_hash = hash_library.get(hash_name, None)
summary['baseline_hash'] = baseline_hash

Expand Down Expand Up @@ -607,13 +619,17 @@ def pytest_runtest_call(self, item): # noqa
with plt.style.context(style, after_reset=True), switch_backend(backend):

# Run test and get figure object
wrap_figure_interceptor(self, item)
yield
fig = self.result
test_name = generate_test_name(item)
if test_name not in self.return_value:
# Test function did not complete successfully
return
fig = self.return_value[test_name]

if remove_text:
remove_ticks_and_titles(fig)

test_name = self.generate_test_name(item)
result_dir = self.make_test_results_dir(item)

summary = {
Expand Down Expand Up @@ -677,10 +693,6 @@ def pytest_runtest_call(self, item): # noqa
if summary['status'] == 'skipped':
pytest.skip(summary['status_msg'])

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
return _pytest_pyfunc_call(self, pyfuncitem)

def generate_summary_json(self):
json_file = self.results_dir / 'results.json'
with open(json_file, 'w') as f:
Expand Down Expand Up @@ -732,13 +744,16 @@ class FigureCloser:

def __init__(self, config):
self.config = config
self.return_value = {}

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):
wrap_figure_interceptor(self, item)
yield
if get_compare(item) is not None:
close_mpl_figure(self.result)

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
return _pytest_pyfunc_call(self, pyfuncitem)
test_name = generate_test_name(item)
if test_name not in self.return_value:
# Test function did not complete successfully
return
fig = self.return_value[test_name]
close_mpl_figure(fig)
7 changes: 7 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ test =

[tool:pytest]
testpaths = "tests"
markers =
image: run test during image comparison only mode.
hash: run test during hash comparison only mode.
filterwarnings =
error
ignore:distutils Version classes are deprecated
ignore:the imp module is deprecated in favour of importlib

[flake8]
max-line-length = 100
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest
from packaging.version import Version

pytest_plugins = ["pytester"]

if Version(pytest.__version__) < Version("6.2.0"):
@pytest.fixture
def pytester(testdir):
return testdir
143 changes: 142 additions & 1 deletion tests/test_pytest_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import subprocess
from pathlib import Path
from unittest import TestCase

import matplotlib
import matplotlib.ft2font
Expand Down Expand Up @@ -259,6 +260,23 @@ def test_succeeds(self):
return fig


class TestClassWithTestCase(TestCase):

# Regression test for a bug that occurred when using unittest.TestCase

def setUp(self):
self.x = [1, 2, 3]

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir_local,
filename='test_succeeds.png',
tolerance=DEFAULT_TOLERANCE)
def test_succeeds(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(self.x)
return fig


# hashlib

@pytest.mark.skipif(not hash_library.exists(), reason="No hash library for this mpl version")
Expand Down Expand Up @@ -514,8 +532,27 @@ def test_fails(self):
return fig
"""

TEST_FAILING_UNITTEST_TESTCASE = """
from unittest import TestCase
import pytest
import matplotlib.pyplot as plt
class TestClassWithTestCase(TestCase):
def setUp(self):
self.x = [1, 2, 3]
@pytest.mark.mpl_image_compare
def test_fails(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(self.x)
return fig
"""

@pytest.mark.parametrize("code", [TEST_FAILING_CLASS, TEST_FAILING_CLASS_SETUP_METHOD])

@pytest.mark.parametrize("code", [
TEST_FAILING_CLASS,
TEST_FAILING_CLASS_SETUP_METHOD,
TEST_FAILING_UNITTEST_TESTCASE,
])
def test_class_fail(code, tmpdir):

test_file = tmpdir.join('test.py').strpath
Expand All @@ -529,3 +566,107 @@ def test_class_fail(code, tmpdir):
# If we don't use --mpl option, the test should succeed
code = call_pytest([test_file])
assert code == 0


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_fail(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_fail():
pytest.fail("Manually failed by user.")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(failed=1)
result.stdout.fnmatch_lines("FAILED*Manually failed by user.*")


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_skip(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_skip():
pytest.skip("Manually skipped by user.")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(skipped=1)


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_importorskip(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_importorskip():
pytest.importorskip("nonexistantmodule")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(skipped=1)


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_xfail(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_xfail():
pytest.xfail()
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(xfailed=1)


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_exit_success(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_exit_success():
pytest.exit("Manually exited by user.", returncode=0)
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes()
assert result.ret == 0
result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_exit_failure(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_exit_fail():
pytest.exit("Manually exited by user.", returncode=1)
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes()
assert result.ret == 1
result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_function_raises(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_raises():
raise ValueError("User code raised an exception.")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(failed=1)
result.stdout.fnmatch_lines("FAILED*ValueError*User code*")
5 changes: 0 additions & 5 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,3 @@ description = check code style, e.g. with flake8
deps = pre-commit
commands =
pre-commit run --all-files

[pytest]
markers =
image: run test during image comparison only mode.
hash: run test during hash comparison only mode.

0 comments on commit 310ee99

Please sign in to comment.