Skip to content

Commit

Permalink
Add hook for doing additional filtering on object type
Browse files Browse the repository at this point in the history
  • Loading branch information
MrThearMan committed Dec 18, 2023
1 parent e8706c1 commit 7ec2d0c
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 79 deletions.
49 changes: 28 additions & 21 deletions query_optimizer/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import TYPE_CHECKING

from django.core.exceptions import FieldDoesNotExist
from django.db.models import ForeignKey, ManyToOneRel, Model, QuerySet
from graphene.relay.connection import ConnectionOptions
from graphene.types.definitions import GrapheneObjectType, GrapheneUnionType
from graphene.utils.str_converters import to_snake_case
from graphene_django.registry import get_global_registry
from graphene_django.types import DjangoObjectTypeOptions
from graphql import (
FieldNode,
Expand Down Expand Up @@ -39,6 +42,10 @@
is_to_one,
)

if TYPE_CHECKING:
from graphene import ObjectType


TModel = TypeVar("TModel", bound=Model)
TCallable = TypeVar("TCallable", bound=Callable)

Expand All @@ -56,7 +63,6 @@ def optimize(
*,
pk: PK = None,
max_complexity: Optional[int] = None,
repopulate: bool = False,
) -> QuerySet[TModel]:
"""
Optimize the given queryset according to the field selections
Expand All @@ -68,13 +74,10 @@ def optimize(
the query cache for that primary key before making query.
:param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed.
Used to protect from malicious queries.
:param repopulate: If True, repopulates the QuerySet._result_cache from the optimizer cache.
This should be used when additional filters are applied to the queryset
after optimization.
:return: The optimized queryset.
"""
# Check if prior optimization has been done already
if not repopulate and is_optimized(queryset):
if is_optimized(queryset):
return queryset

field_type = get_field_type(info)
Expand All @@ -97,21 +100,13 @@ def optimize(
queryset._result_cache = [cached_item]
return queryset

if repopulate:
queryset._result_cache: list[TModel] = []
for pk in queryset.values_list("pk", flat=True):
cached_item = get_from_query_cache(info.operation, info.schema, queryset.model, pk, store)
if cached_item is None:
msg = (
f"Could not find '{queryset.model.__class__.__name__}' object with primary key "
f"'{pk}' from the optimizer cache. Check that the queryset results are narrowed "
f"and not expanded when repopulating."
)
raise ValueError(msg)
queryset._result_cache.append(cached_item)
return queryset

queryset = store.optimize_queryset(queryset, pk=pk)

# Apply custom filtering based on the ObjectType
object_type: ObjectType | None = get_global_registry().get_type_for_model(queryset.model)
if hasattr(object_type, "filter_queryset") and callable(object_type.filter_queryset):
queryset = object_type.filter_queryset(queryset, info)

if optimizer.cache_results:
store_in_query_cache(info.operation, queryset, info.schema, store)

Expand Down Expand Up @@ -250,7 +245,8 @@ def find_field_from_model(

nested_store = QueryOptimizerStore(model=related_model)
if is_to_many(model_field):
store.prefetch_stores[model_field.name] = (nested_store, related_model.objects.all())
queryset = self.get_filtered_queryset(related_model)
store.prefetch_stores[model_field.name] = (nested_store, queryset)
elif is_to_one(model_field):
store.select_stores[model_field.name] = nested_store
else: # pragma: no cover
Expand Down Expand Up @@ -342,9 +338,20 @@ def handle_to_many(
if isinstance(model_field, ManyToOneRel):
nested_store.related_fields.append(model_field.field.name)

related_queryset: QuerySet[Model] = model_field.related_model.objects.all()
related_model: type[Model] = model_field.related_model # type: ignore[assignment]
if related_model == "self": # pragma: no cover
related_model = model_field.model

related_queryset = self.get_filtered_queryset(related_model)
store.prefetch_stores[model_field_name] = nested_store, related_queryset

def get_filtered_queryset(self, model: type[TModel]) -> QuerySet[TModel]:
qs: QuerySet = model.objects.all()
object_type: ObjectType | None = get_global_registry().get_type_for_model(model)
if hasattr(object_type, "filter_queryset") and callable(object_type.filter_queryset):
qs = object_type.filter_queryset(qs, self.info)
return qs

def optimize_fragment_spread(
self,
field_type: GrapheneObjectType,
Expand Down
22 changes: 16 additions & 6 deletions query_optimizer/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Callable, Collection, Hashable, Iterable, NamedTuple, Optional, TypeVar, Union
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Collection, Hashable, Iterable, NamedTuple, Optional, TypeVar, Union

from graphene.relay.connection import ConnectionOptions
from graphene_django.types import DjangoObjectTypeOptions
Expand All @@ -10,7 +12,6 @@
from typing_extensions import TypeAlias, TypeGuard


import graphql
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.core.handlers.wsgi import WSGIRequest
from django.db.models import (
Expand All @@ -24,6 +25,10 @@
Model,
OneToOneField,
)
from graphql import GraphQLResolveInfo

if TYPE_CHECKING:
from django.contrib.auth.models import AnonymousUser, User

__all__ = [
"Any",
Expand All @@ -48,10 +53,6 @@
]


class GQLInfo(graphql.GraphQLResolveInfo):
context: WSGIRequest


TModel = TypeVar("TModel", bound=Model)
TableName: TypeAlias = str
StoreStr: TypeAlias = str
Expand All @@ -61,3 +62,12 @@ class GQLInfo(graphql.GraphQLResolveInfo):
ToManyField: TypeAlias = Union[GenericRelation, ManyToManyField, ManyToOneRel, ManyToManyRel]
ToOneField: TypeAlias = Union[GenericRelation, ForeignObject, ForeignKey, OneToOneField]
TypeOptions: TypeAlias = Union[DjangoObjectTypeOptions, ConnectionOptions]
AnyUser: TypeAlias = Union["User", "AnonymousUser"]


class UserHintedWSGIRequest(WSGIRequest):
user: AnyUser


class GQLInfo(GraphQLResolveInfo):
context: UserHintedWSGIRequest
2 changes: 1 addition & 1 deletion query_optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def mark_optimized(queryset: QuerySet) -> None:
queryset._hints[optimizer_settings.OPTIMIZER_MARK] = True # type: ignore[attr-defined]


def mark_unoptimized(queryset: QuerySet) -> None:
def mark_unoptimized(queryset: QuerySet) -> None: # pragma: no cover
"""Mark queryset as unoptimized so that later optimizers will run optimization"""
queryset._hints.pop(optimizer_settings.OPTIMIZER_MARK, None) # type: ignore[attr-defined]

Expand Down
16 changes: 6 additions & 10 deletions tests/example/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from django_filters import CharFilter, FilterSet, OrderingFilter
from graphene import relay

from query_optimizer import DjangoObjectType, optimize, required_fields
from query_optimizer import DjangoObjectType, required_fields
from query_optimizer.typing import GQLInfo
from query_optimizer.utils import can_optimize
from tests.example.models import (
Apartment,
ApartmentProxy,
Expand Down Expand Up @@ -96,6 +95,10 @@ class ApartmentType(DjangoObjectType):
def max_complexity(cls) -> int:
return 10

@classmethod
def filter_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet:
return queryset.filter(rooms__isnull=False)

class Meta:
model = Apartment

Expand All @@ -106,14 +109,7 @@ class Meta:

@classmethod
def filter_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet:
if can_optimize(info):
queryset = optimize(
queryset.filter(purchase_price__gte=1),
info,
max_complexity=cls.max_complexity(),
repopulate=True,
)
return queryset
return queryset.filter(purchase_price__gte=1)


class OwnerType(DjangoObjectType):
Expand Down
42 changes: 1 addition & 41 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from graphql_relay import to_global_id

from tests.example.models import Apartment, Building, HousingCompany
from tests.example.types import ApartmentNode, SaleType
from tests.example.types import ApartmentNode
from tests.example.utils import count_queries

pytestmark = pytest.mark.django_db
Expand Down Expand Up @@ -785,43 +785,3 @@ def test_optimizer_max_complexity_reached(client_query):

queries = len(results.queries)
assert queries == 0, results.message


def test_optimizer_many_to_one_relations__additional_filtering(client_query):
query = """
query {
allApartments {
streetAddress
stair
apartmentNumber
sales {
purchaseDate
ownerships {
percentage
owner {
name
}
}
}
}
}
"""

original_get_queryset = SaleType.get_queryset
try:
SaleType.get_queryset = SaleType.filter_queryset
with count_queries() as results:
response = client_query(query)
finally:
SaleType.get_queryset = original_get_queryset

content = json.loads(response.content)
assert "errors" not in content, content["errors"]
assert "data" in content, content
assert "allApartments" in content["data"], content["data"]
apartments = content["data"]["allApartments"]
assert len(apartments) != 0, apartments

queries = len(results.queries)
# Normal 3 queries, and an additional 40+ to filter the results in Sales.get_queryset
assert queries > 40, results.message

0 comments on commit 7ec2d0c

Please sign in to comment.