From b7f51ada6bdf20f586197d8762c013626d96b69b Mon Sep 17 00:00:00 2001 From: Mark Gensler Date: Mon, 29 Jan 2024 16:00:58 +0000 Subject: [PATCH] Added assertMessages() from django.contrib.messages. --- AUTHORS | 1 + docs/changelog.rst | 9 +++++++++ pytest_django/asserts.py | 28 +++++++++++++++++++++++++++- tests/test_asserts.py | 11 ++++++++++- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 3f9b7ea6..060864a4 100644 --- a/AUTHORS +++ b/AUTHORS @@ -19,3 +19,4 @@ Donald Stufft Nicolas Delaby Hasan Ramezani Michael Howitz +Mark Gensler diff --git a/docs/changelog.rst b/docs/changelog.rst index de564d40..29f7ada0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,15 @@ Changelog ========= +Pending +------- + +Improvements +^^^^^^^^^^^^ + +* Added `pytest.asserts.assertMessages()` to mimic the behaviour of the + `django.contrib.messages.test.MessagesTestMixin` function for Django versions >= 5.0. + v4.7.0 (2023-11-08) ------------------- diff --git a/pytest_django/asserts.py b/pytest_django/asserts.py index f305fab0..cef2e6f8 100644 --- a/pytest_django/asserts.py +++ b/pytest_django/asserts.py @@ -6,10 +6,22 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Sequence +from django import VERSION from django.test import LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase -test_case = TestCase("run") +use_contrib_messages = VERSION >= (5, 0) + +if use_contrib_messages: + from django.contrib.messages import Message + from django.contrib.messages.test import MessagesTestMixin + + class MessagesTestCase(MessagesTestMixin, TestCase): + pass + + test_case = MessagesTestCase("run") +else: + test_case = TestCase("run") def _wrapper(name: str): @@ -31,6 +43,11 @@ def assertion_func(*args, **kwargs): {attr for attr in vars(TransactionTestCase) if attr.startswith("assert")}, ) +if use_contrib_messages: + assertions_names.update( + {attr for attr in vars(MessagesTestMixin) if attr.startswith("assert")}, + ) + for assert_func in assertions_names: globals()[assert_func] = _wrapper(assert_func) __all__.append(assert_func) # noqa: PYI056 @@ -213,6 +230,15 @@ def assertNumQueries( ): ... + if use_contrib_messages: + + def assertMessages( + response: HttpResponseBase, + expected_messages: Sequence[Message], + ordered: bool = ..., + ) -> None: + ... + # Fallback in case Django adds new asserts. def __getattr__(name: str) -> Callable[..., Any]: ... diff --git a/tests/test_asserts.py b/tests/test_asserts.py index d8ef2455..7a2db7dc 100644 --- a/tests/test_asserts.py +++ b/tests/test_asserts.py @@ -17,9 +17,18 @@ def _get_actual_assertions_names() -> list[str]: """ from unittest import TestCase as DefaultTestCase + from django import VERSION from django.test import TestCase as DjangoTestCase - obj = DjangoTestCase("run") + if VERSION >= (5, 0): + from django.contrib.messages.test import MessagesTestMixin + + class MessagesTestCase(MessagesTestMixin, DjangoTestCase): + pass + + obj = MessagesTestCase("run") + else: + obj = DjangoTestCase("run") def is_assert(func) -> bool: return func.startswith("assert") and "_" not in func