Skip to content

Commit

Permalink
Mock plt.subplots which has multiple outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Dec 7, 2023
1 parent ff174b2 commit 9cea29c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
15 changes: 11 additions & 4 deletions pytest_plt/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def pytest_report_teststatus(report):


class Mock:
multi_functions = {}

def __init__(self, *args, **kwargs):
pass

Expand All @@ -81,10 +83,16 @@ def __getattr__(cls, name):
mockType = type(name, (), {})
mockType.__module__ = __name__
return mockType
elif name in cls.multi_functions:
return lambda *args, **kwargs: tuple(Mock() for _ in range(cls.multi_functions[name]))
else:
return Mock()


class PltMock(Mock):
multi_functions = {"subplots": 2}


class Recorder:
def __init__(self, dirname, nodeid, filename_drop=None):
self.dirname = dirname
Expand Down Expand Up @@ -139,7 +147,7 @@ def __enter__(self):
if self.record:
self.plt = mpl_plt
else:
self.plt = Mock()
self.plt = PltMock()
self.plt.saveas = self.get_filename(ext="pdf")
return self.plt

Expand Down Expand Up @@ -174,8 +182,7 @@ def save(self, path):

@pytest.fixture
def plt(request):
"""
A pyplot-compatible plotting interface.
"""A pyplot-compatible plotting interface.
Use this to create plots in your tests using the ``matplotlib.pyplot``
interface.
Expand Down Expand Up @@ -210,4 +217,4 @@ def _finalize():
request.node.user_properties.append(("plt_saved", plotter.saved))

request.addfinalizer(_finalize)
return plotter.__enter__() # pylint: disable=unnecessary-dunder-call
return plotter.__enter__()
7 changes: 7 additions & 0 deletions pytest_plt/tests/test_plt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def test_mock_iter(plt):
plt.saveas = None


def test_mock_subplots(plt):
fig, axes = plt.subplots(2, 1)
axes[0].plot(np.arange(10))
axes[1].plot(-np.arange(10))
fig.tight_layout()


def test_simple_plot(plt):
plt.plot(np.linspace(0, 1, 20), np.linspace(0, 2, 20))

Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def read(*filenames, **kwargs):
"sphinx",
]
optional_req = []
tests_req = []
tests_req = [
"pytest-cov", # required since `addopts = --cov` in setup.cfg
]

setup(
name="pytest-plt",
Expand Down

0 comments on commit 9cea29c

Please sign in to comment.