From 7ec2d0cadc9fc556f0238427820557ea976009cb Mon Sep 17 00:00:00 2001 From: Matti Lamppu Date: Mon, 18 Dec 2023 21:38:48 +0200 Subject: [PATCH] Add hook for doing additional filtering on object type --- query_optimizer/optimizer.py | 49 ++++++++++++++++++++---------------- query_optimizer/typing.py | 22 +++++++++++----- query_optimizer/utils.py | 2 +- tests/example/types.py | 16 +++++------- tests/test_optimizer.py | 42 +------------------------------ 5 files changed, 52 insertions(+), 79 deletions(-) diff --git a/query_optimizer/optimizer.py b/query_optimizer/optimizer.py index e400d28..ea391e6 100644 --- a/query_optimizer/optimizer.py +++ b/query_optimizer/optimizer.py @@ -1,8 +1,11 @@ +from typing import TYPE_CHECKING + from django.core.exceptions import FieldDoesNotExist from django.db.models import ForeignKey, ManyToOneRel, Model, QuerySet from graphene.relay.connection import ConnectionOptions from graphene.types.definitions import GrapheneObjectType, GrapheneUnionType from graphene.utils.str_converters import to_snake_case +from graphene_django.registry import get_global_registry from graphene_django.types import DjangoObjectTypeOptions from graphql import ( FieldNode, @@ -39,6 +42,10 @@ is_to_one, ) +if TYPE_CHECKING: + from graphene import ObjectType + + TModel = TypeVar("TModel", bound=Model) TCallable = TypeVar("TCallable", bound=Callable) @@ -56,7 +63,6 @@ def optimize( *, pk: PK = None, max_complexity: Optional[int] = None, - repopulate: bool = False, ) -> QuerySet[TModel]: """ Optimize the given queryset according to the field selections @@ -68,13 +74,10 @@ def optimize( the query cache for that primary key before making query. :param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed. Used to protect from malicious queries. - :param repopulate: If True, repopulates the QuerySet._result_cache from the optimizer cache. - This should be used when additional filters are applied to the queryset - after optimization. :return: The optimized queryset. """ # Check if prior optimization has been done already - if not repopulate and is_optimized(queryset): + if is_optimized(queryset): return queryset field_type = get_field_type(info) @@ -97,21 +100,13 @@ def optimize( queryset._result_cache = [cached_item] return queryset - if repopulate: - queryset._result_cache: list[TModel] = [] - for pk in queryset.values_list("pk", flat=True): - cached_item = get_from_query_cache(info.operation, info.schema, queryset.model, pk, store) - if cached_item is None: - msg = ( - f"Could not find '{queryset.model.__class__.__name__}' object with primary key " - f"'{pk}' from the optimizer cache. Check that the queryset results are narrowed " - f"and not expanded when repopulating." - ) - raise ValueError(msg) - queryset._result_cache.append(cached_item) - return queryset - queryset = store.optimize_queryset(queryset, pk=pk) + + # Apply custom filtering based on the ObjectType + object_type: ObjectType | None = get_global_registry().get_type_for_model(queryset.model) + if hasattr(object_type, "filter_queryset") and callable(object_type.filter_queryset): + queryset = object_type.filter_queryset(queryset, info) + if optimizer.cache_results: store_in_query_cache(info.operation, queryset, info.schema, store) @@ -250,7 +245,8 @@ def find_field_from_model( nested_store = QueryOptimizerStore(model=related_model) if is_to_many(model_field): - store.prefetch_stores[model_field.name] = (nested_store, related_model.objects.all()) + queryset = self.get_filtered_queryset(related_model) + store.prefetch_stores[model_field.name] = (nested_store, queryset) elif is_to_one(model_field): store.select_stores[model_field.name] = nested_store else: # pragma: no cover @@ -342,9 +338,20 @@ def handle_to_many( if isinstance(model_field, ManyToOneRel): nested_store.related_fields.append(model_field.field.name) - related_queryset: QuerySet[Model] = model_field.related_model.objects.all() + related_model: type[Model] = model_field.related_model # type: ignore[assignment] + if related_model == "self": # pragma: no cover + related_model = model_field.model + + related_queryset = self.get_filtered_queryset(related_model) store.prefetch_stores[model_field_name] = nested_store, related_queryset + def get_filtered_queryset(self, model: type[TModel]) -> QuerySet[TModel]: + qs: QuerySet = model.objects.all() + object_type: ObjectType | None = get_global_registry().get_type_for_model(model) + if hasattr(object_type, "filter_queryset") and callable(object_type.filter_queryset): + qs = object_type.filter_queryset(qs, self.info) + return qs + def optimize_fragment_spread( self, field_type: GrapheneObjectType, diff --git a/query_optimizer/typing.py b/query_optimizer/typing.py index bbb9558..991906b 100644 --- a/query_optimizer/typing.py +++ b/query_optimizer/typing.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, Collection, Hashable, Iterable, NamedTuple, Optional, TypeVar, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Collection, Hashable, Iterable, NamedTuple, Optional, TypeVar, Union from graphene.relay.connection import ConnectionOptions from graphene_django.types import DjangoObjectTypeOptions @@ -10,7 +12,6 @@ from typing_extensions import TypeAlias, TypeGuard -import graphql from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.core.handlers.wsgi import WSGIRequest from django.db.models import ( @@ -24,6 +25,10 @@ Model, OneToOneField, ) +from graphql import GraphQLResolveInfo + +if TYPE_CHECKING: + from django.contrib.auth.models import AnonymousUser, User __all__ = [ "Any", @@ -48,10 +53,6 @@ ] -class GQLInfo(graphql.GraphQLResolveInfo): - context: WSGIRequest - - TModel = TypeVar("TModel", bound=Model) TableName: TypeAlias = str StoreStr: TypeAlias = str @@ -61,3 +62,12 @@ class GQLInfo(graphql.GraphQLResolveInfo): ToManyField: TypeAlias = Union[GenericRelation, ManyToManyField, ManyToOneRel, ManyToManyRel] ToOneField: TypeAlias = Union[GenericRelation, ForeignObject, ForeignKey, OneToOneField] TypeOptions: TypeAlias = Union[DjangoObjectTypeOptions, ConnectionOptions] +AnyUser: TypeAlias = Union["User", "AnonymousUser"] + + +class UserHintedWSGIRequest(WSGIRequest): + user: AnyUser + + +class GQLInfo(GraphQLResolveInfo): + context: UserHintedWSGIRequest diff --git a/query_optimizer/utils.py b/query_optimizer/utils.py index 5834a45..2352790 100644 --- a/query_optimizer/utils.py +++ b/query_optimizer/utils.py @@ -72,7 +72,7 @@ def mark_optimized(queryset: QuerySet) -> None: queryset._hints[optimizer_settings.OPTIMIZER_MARK] = True # type: ignore[attr-defined] -def mark_unoptimized(queryset: QuerySet) -> None: +def mark_unoptimized(queryset: QuerySet) -> None: # pragma: no cover """Mark queryset as unoptimized so that later optimizers will run optimization""" queryset._hints.pop(optimizer_settings.OPTIMIZER_MARK, None) # type: ignore[attr-defined] diff --git a/tests/example/types.py b/tests/example/types.py index f84a1bb..15cfcb4 100644 --- a/tests/example/types.py +++ b/tests/example/types.py @@ -4,9 +4,8 @@ from django_filters import CharFilter, FilterSet, OrderingFilter from graphene import relay -from query_optimizer import DjangoObjectType, optimize, required_fields +from query_optimizer import DjangoObjectType, required_fields from query_optimizer.typing import GQLInfo -from query_optimizer.utils import can_optimize from tests.example.models import ( Apartment, ApartmentProxy, @@ -96,6 +95,10 @@ class ApartmentType(DjangoObjectType): def max_complexity(cls) -> int: return 10 + @classmethod + def filter_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet: + return queryset.filter(rooms__isnull=False) + class Meta: model = Apartment @@ -106,14 +109,7 @@ class Meta: @classmethod def filter_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet: - if can_optimize(info): - queryset = optimize( - queryset.filter(purchase_price__gte=1), - info, - max_complexity=cls.max_complexity(), - repopulate=True, - ) - return queryset + return queryset.filter(purchase_price__gte=1) class OwnerType(DjangoObjectType): diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 053f789..7ebd73a 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -6,7 +6,7 @@ from graphql_relay import to_global_id from tests.example.models import Apartment, Building, HousingCompany -from tests.example.types import ApartmentNode, SaleType +from tests.example.types import ApartmentNode from tests.example.utils import count_queries pytestmark = pytest.mark.django_db @@ -785,43 +785,3 @@ def test_optimizer_max_complexity_reached(client_query): queries = len(results.queries) assert queries == 0, results.message - - -def test_optimizer_many_to_one_relations__additional_filtering(client_query): - query = """ - query { - allApartments { - streetAddress - stair - apartmentNumber - sales { - purchaseDate - ownerships { - percentage - owner { - name - } - } - } - } - } - """ - - original_get_queryset = SaleType.get_queryset - try: - SaleType.get_queryset = SaleType.filter_queryset - with count_queries() as results: - response = client_query(query) - finally: - SaleType.get_queryset = original_get_queryset - - content = json.loads(response.content) - assert "errors" not in content, content["errors"] - assert "data" in content, content - assert "allApartments" in content["data"], content["data"] - apartments = content["data"]["allApartments"] - assert len(apartments) != 0, apartments - - queries = len(results.queries) - # Normal 3 queries, and an additional 40+ to filter the results in Sales.get_queryset - assert queries > 40, results.message