Skip to content

Commit

Permalink
Add pytest_django.DjangoAssertNumQueries for typing purposes
Browse files Browse the repository at this point in the history
This allows typing the `django_assert_num_queries` and
`django_assert_max_num_queries` fixtures.
  • Loading branch information
bluetech committed Nov 8, 2023
1 parent 28484f4 commit 16ee779
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 9 deletions.
18 changes: 18 additions & 0 deletions docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,15 @@ Example usage::

assert 'foo' in captured.captured_queries[0]['sql']

If you use type annotations, you can annotate the fixture like this::

from pytest_django import DjangoAssertNumQueries

def test_num_queries(
django_assert_num_queries: DjangoAssertNumQueries,
):
...


.. fixture:: django_assert_max_num_queries

Expand All @@ -470,6 +479,15 @@ Example usage::
Item.objects.create('foo')
Item.objects.create('bar')

If you use type annotations, you can annotate the fixture like this::

from pytest_django import DjangoAssertNumQueries

def test_max_num_queries(
django_assert_max_num_queries: DjangoAssertNumQueries,
):
...


.. fixture:: django_capture_on_commit_callbacks

Expand Down
3 changes: 2 additions & 1 deletion pytest_django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
__version__ = "unknown"


from .fixtures import DjangoCaptureOnCommitCallbacks
from .fixtures import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks
from .plugin import DjangoDbBlocker


__all__ = [
"__version__",
"DjangoAssertNumQueries",
"DjangoCaptureOnCommitCallbacks",
"DjangoDbBlocker",
]
19 changes: 16 additions & 3 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,25 @@ def _live_server_helper(request: pytest.FixtureRequest) -> Generator[None, None,
live_server._live_server_modified_settings.disable()


class DjangoAssertNumQueries(Protocol):
"""The type of the `django_assert_num_queries` and
`django_assert_max_num_queries` fixtures."""

def __call__(
self,
num: int,
connection: Any | None = ...,
info: str | None = ...,
) -> django.test.utils.CaptureQueriesContext:
pass # pragma: no cover


@contextmanager
def _assert_num_queries(
config: pytest.Config,
num: int,
exact: bool = True,
connection=None,
connection: Any | None = None,
info: str | None = None,
) -> Generator[django.test.utils.CaptureQueriesContext, None, None]:
from django.test.utils import CaptureQueriesContext
Expand Down Expand Up @@ -641,12 +654,12 @@ def _assert_num_queries(


@pytest.fixture()
def django_assert_num_queries(pytestconfig: pytest.Config):
def django_assert_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNumQueries:
return partial(_assert_num_queries, pytestconfig)


@pytest.fixture()
def django_assert_max_num_queries(pytestconfig: pytest.Config):
def django_assert_max_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNumQueries:
return partial(_assert_num_queries, pytestconfig, exact=False)


Expand Down
14 changes: 9 additions & 5 deletions tests/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .helpers import DjangoPytester

from pytest_django import DjangoCaptureOnCommitCallbacks, DjangoDbBlocker
from pytest_django import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks, DjangoDbBlocker
from pytest_django_test.app.models import Item


Expand Down Expand Up @@ -91,7 +91,7 @@ def test_async_rf(async_rf: AsyncRequestFactory) -> None:
@pytest.mark.django_db
def test_django_assert_num_queries_db(
request: pytest.FixtureRequest,
django_assert_num_queries,
django_assert_num_queries: DjangoAssertNumQueries,
) -> None:
with nonverbose_config(request.config):
with django_assert_num_queries(3):
Expand All @@ -111,7 +111,7 @@ def test_django_assert_num_queries_db(
@pytest.mark.django_db
def test_django_assert_max_num_queries_db(
request: pytest.FixtureRequest,
django_assert_max_num_queries,
django_assert_max_num_queries: DjangoAssertNumQueries,
) -> None:
with nonverbose_config(request.config):
with django_assert_max_num_queries(2):
Expand All @@ -134,7 +134,9 @@ def test_django_assert_max_num_queries_db(

@pytest.mark.django_db(transaction=True)
def test_django_assert_num_queries_transactional_db(
request: pytest.FixtureRequest, transactional_db: None, django_assert_num_queries
request: pytest.FixtureRequest,
transactional_db: None,
django_assert_num_queries: DjangoAssertNumQueries,
) -> None:
with nonverbose_config(request.config):
with transaction.atomic():
Expand Down Expand Up @@ -187,7 +189,9 @@ def test_queries(django_assert_num_queries):


@pytest.mark.django_db
def test_django_assert_num_queries_db_connection(django_assert_num_queries) -> None:
def test_django_assert_num_queries_db_connection(
django_assert_num_queries: DjangoAssertNumQueries,
) -> None:
from django.db import connection

with django_assert_num_queries(1, connection=connection):
Expand Down

0 comments on commit 16ee779

Please sign in to comment.