Skip to content

Commit 16ee779

Browse files
committed
Add pytest_django.DjangoAssertNumQueries for typing purposes
This allows typing the `django_assert_num_queries` and `django_assert_max_num_queries` fixtures.
1 parent 28484f4 commit 16ee779

File tree

4 files changed

+45
-9
lines changed

4 files changed

+45
-9
lines changed

docs/helpers.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,15 @@ Example usage::
447447

448448
assert 'foo' in captured.captured_queries[0]['sql']
449449

450+
If you use type annotations, you can annotate the fixture like this::
451+
452+
from pytest_django import DjangoAssertNumQueries
453+
454+
def test_num_queries(
455+
django_assert_num_queries: DjangoAssertNumQueries,
456+
):
457+
...
458+
450459

451460
.. fixture:: django_assert_max_num_queries
452461

@@ -470,6 +479,15 @@ Example usage::
470479
Item.objects.create('foo')
471480
Item.objects.create('bar')
472481

482+
If you use type annotations, you can annotate the fixture like this::
483+
484+
from pytest_django import DjangoAssertNumQueries
485+
486+
def test_max_num_queries(
487+
django_assert_max_num_queries: DjangoAssertNumQueries,
488+
):
489+
...
490+
473491

474492
.. fixture:: django_capture_on_commit_callbacks
475493

pytest_django/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
__version__ = "unknown"
66

77

8-
from .fixtures import DjangoCaptureOnCommitCallbacks
8+
from .fixtures import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks
99
from .plugin import DjangoDbBlocker
1010

1111

1212
__all__ = [
1313
"__version__",
14+
"DjangoAssertNumQueries",
1415
"DjangoCaptureOnCommitCallbacks",
1516
"DjangoDbBlocker",
1617
]

pytest_django/fixtures.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,12 +601,25 @@ def _live_server_helper(request: pytest.FixtureRequest) -> Generator[None, None,
601601
live_server._live_server_modified_settings.disable()
602602

603603

604+
class DjangoAssertNumQueries(Protocol):
605+
"""The type of the `django_assert_num_queries` and
606+
`django_assert_max_num_queries` fixtures."""
607+
608+
def __call__(
609+
self,
610+
num: int,
611+
connection: Any | None = ...,
612+
info: str | None = ...,
613+
) -> django.test.utils.CaptureQueriesContext:
614+
pass # pragma: no cover
615+
616+
604617
@contextmanager
605618
def _assert_num_queries(
606619
config: pytest.Config,
607620
num: int,
608621
exact: bool = True,
609-
connection=None,
622+
connection: Any | None = None,
610623
info: str | None = None,
611624
) -> Generator[django.test.utils.CaptureQueriesContext, None, None]:
612625
from django.test.utils import CaptureQueriesContext
@@ -641,12 +654,12 @@ def _assert_num_queries(
641654

642655

643656
@pytest.fixture()
644-
def django_assert_num_queries(pytestconfig: pytest.Config):
657+
def django_assert_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNumQueries:
645658
return partial(_assert_num_queries, pytestconfig)
646659

647660

648661
@pytest.fixture()
649-
def django_assert_max_num_queries(pytestconfig: pytest.Config):
662+
def django_assert_max_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNumQueries:
650663
return partial(_assert_num_queries, pytestconfig, exact=False)
651664

652665

tests/test_fixtures.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from .helpers import DjangoPytester
2020

21-
from pytest_django import DjangoCaptureOnCommitCallbacks, DjangoDbBlocker
21+
from pytest_django import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks, DjangoDbBlocker
2222
from pytest_django_test.app.models import Item
2323

2424

@@ -91,7 +91,7 @@ def test_async_rf(async_rf: AsyncRequestFactory) -> None:
9191
@pytest.mark.django_db
9292
def test_django_assert_num_queries_db(
9393
request: pytest.FixtureRequest,
94-
django_assert_num_queries,
94+
django_assert_num_queries: DjangoAssertNumQueries,
9595
) -> None:
9696
with nonverbose_config(request.config):
9797
with django_assert_num_queries(3):
@@ -111,7 +111,7 @@ def test_django_assert_num_queries_db(
111111
@pytest.mark.django_db
112112
def test_django_assert_max_num_queries_db(
113113
request: pytest.FixtureRequest,
114-
django_assert_max_num_queries,
114+
django_assert_max_num_queries: DjangoAssertNumQueries,
115115
) -> None:
116116
with nonverbose_config(request.config):
117117
with django_assert_max_num_queries(2):
@@ -134,7 +134,9 @@ def test_django_assert_max_num_queries_db(
134134

135135
@pytest.mark.django_db(transaction=True)
136136
def test_django_assert_num_queries_transactional_db(
137-
request: pytest.FixtureRequest, transactional_db: None, django_assert_num_queries
137+
request: pytest.FixtureRequest,
138+
transactional_db: None,
139+
django_assert_num_queries: DjangoAssertNumQueries,
138140
) -> None:
139141
with nonverbose_config(request.config):
140142
with transaction.atomic():
@@ -187,7 +189,9 @@ def test_queries(django_assert_num_queries):
187189

188190

189191
@pytest.mark.django_db
190-
def test_django_assert_num_queries_db_connection(django_assert_num_queries) -> None:
192+
def test_django_assert_num_queries_db_connection(
193+
django_assert_num_queries: DjangoAssertNumQueries,
194+
) -> None:
191195
from django.db import connection
192196

193197
with django_assert_num_queries(1, connection=connection):

0 commit comments

Comments
 (0)