Skip to content

Commit

Permalink
Added assertMessages() from django.contrib.messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
sdolemelipone committed Jan 29, 2024
1 parent 5283aa4 commit b7f51ad
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ Donald Stufft <[email protected]>
Nicolas Delaby <[email protected]>
Hasan Ramezani <[email protected]>
Michael Howitz
Mark Gensler
9 changes: 9 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
=========

Pending
-------

Improvements
^^^^^^^^^^^^

* Added `pytest.asserts.assertMessages()` to mimic the behaviour of the
`django.contrib.messages.test.MessagesTestMixin` function for Django versions >= 5.0.

v4.7.0 (2023-11-08)
-------------------

Expand Down
28 changes: 27 additions & 1 deletion pytest_django/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,22 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Sequence

from django import VERSION
from django.test import LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase


test_case = TestCase("run")
use_contrib_messages = VERSION >= (5, 0)

if use_contrib_messages:
from django.contrib.messages import Message
from django.contrib.messages.test import MessagesTestMixin

class MessagesTestCase(MessagesTestMixin, TestCase):
pass

test_case = MessagesTestCase("run")
else:
test_case = TestCase("run")


def _wrapper(name: str):
Expand All @@ -31,6 +43,11 @@ def assertion_func(*args, **kwargs):
{attr for attr in vars(TransactionTestCase) if attr.startswith("assert")},
)

if use_contrib_messages:
assertions_names.update(
{attr for attr in vars(MessagesTestMixin) if attr.startswith("assert")},
)

for assert_func in assertions_names:
globals()[assert_func] = _wrapper(assert_func)
__all__.append(assert_func) # noqa: PYI056
Expand Down Expand Up @@ -213,6 +230,15 @@ def assertNumQueries(
):
...

if use_contrib_messages:

def assertMessages(
response: HttpResponseBase,
expected_messages: Sequence[Message],
ordered: bool = ...,
) -> None:
...

# Fallback in case Django adds new asserts.
def __getattr__(name: str) -> Callable[..., Any]:
...
11 changes: 10 additions & 1 deletion tests/test_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@ def _get_actual_assertions_names() -> list[str]:
"""
from unittest import TestCase as DefaultTestCase

from django import VERSION
from django.test import TestCase as DjangoTestCase

obj = DjangoTestCase("run")
if VERSION >= (5, 0):
from django.contrib.messages.test import MessagesTestMixin

class MessagesTestCase(MessagesTestMixin, DjangoTestCase):
pass

obj = MessagesTestCase("run")
else:
obj = DjangoTestCase("run")

def is_assert(func) -> bool:
return func.startswith("assert") and "_" not in func
Expand Down

0 comments on commit b7f51ad

Please sign in to comment.