diff --git a/query_optimizer/optimizer.py b/query_optimizer/optimizer.py
index c92e1b5..e400d28 100644
--- a/query_optimizer/optimizer.py
+++ b/query_optimizer/optimizer.py
@@ -34,6 +34,7 @@
     get_selections,
     get_underlying_type,
     is_foreign_key_id,
+    is_optimized,
     is_to_many,
     is_to_one,
 )
@@ -52,8 +53,10 @@
 def optimize(
     queryset: QuerySet[TModel],
     info: GQLInfo,
-    max_complexity: Optional[int] = None,
+    *,
     pk: PK = None,
+    max_complexity: Optional[int] = None,
+    repopulate: bool = False,
 ) -> QuerySet[TModel]:
     """
     Optimize the given queryset according to the field selections
@@ -61,14 +64,17 @@
 
     :param queryset: Base queryset to optimize from.
     :param info: The GraphQLResolveInfo object used in the optimization process.
-    :param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed.
-                           Used to protect from malicious queries.
     :param pk: Primary key for an item in the queryset model. If set, optimizer will check
                the query cache for that primary key before making query.
+    :param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed.
+                           Used to protect from malicious queries.
+    :param repopulate: If True, repopulates the QuerySet._result_cache from the optimizer cache.
+                       This should be used when additional filters are applied to the queryset
+                       after optimization.
     :return: The optimized queryset.
     """
     # Check if prior optimization has been done already
-    if queryset._hints.get(optimizer_settings.OPTIMIZER_MARK, False):  # type: ignore[attr-defined]
+    if not repopulate and is_optimized(queryset):
         return queryset
 
     field_type = get_field_type(info)
@@ -91,6 +97,20 @@
             queryset._result_cache = [cached_item]
             return queryset
 
+    if repopulate:
+        queryset._result_cache: list[TModel] = []
+        for pk in queryset.values_list("pk", flat=True):
+            cached_item = get_from_query_cache(info.operation, info.schema, queryset.model, pk, store)
+            if cached_item is None:
+                msg = (
+                    f"Could not find '{queryset.model.__name__}' object with primary key "
+                    f"'{pk}' from the optimizer cache. Check that the queryset results are narrowed "
+                    f"and not expanded when repopulating."
+                )
+                raise ValueError(msg)
+            queryset._result_cache.append(cached_item)
+        return queryset
+
     queryset = store.optimize_queryset(queryset, pk=pk)
     if optimizer.cache_results:
         store_in_query_cache(info.operation, queryset, info.schema, store)
@@ -296,8 +316,8 @@ def handle_to_one(
         model_field.model,
     )
 
-    if isinstance(model_field, ForeignKey):  # Add connecting entity
-        store.only_fields.append(model_field_name)
+    if isinstance(model_field, ForeignKey):
+        store.related_fields.append(model_field_name)
 
     store.select_stores[model_field_name] = nested_store
 
@@ -319,8 +339,8 @@ def handle_to_many(
         model_field.model,
    )
 
-    if isinstance(model_field, ManyToOneRel):  # Add connecting entity
-        nested_store.only_fields.append(model_field.field.name)
+    if isinstance(model_field, ManyToOneRel):
+        nested_store.related_fields.append(model_field.field.name)
 
     related_queryset: QuerySet[Model] = model_field.related_model.objects.all()
     store.prefetch_stores[model_field_name] = nested_store, related_queryset
diff --git a/query_optimizer/store.py b/query_optimizer/store.py
index 6698d0c..7bd3c44 100644
--- a/query_optimizer/store.py
+++ b/query_optimizer/store.py
@@ -5,7 +5,7 @@
 
 from .settings import optimizer_settings
 from .typing import PK, TypeVar
-from .utils import unique
+from .utils import mark_optimized, unique
 
 TModel = TypeVar("TModel", bound=Model)
 
@@ -18,6 +18,7 @@
 
 @dataclass
 class CompilationResults:
     only_fields: list[str]
+    related_fields: list[str]
     select_related: list[str]
     prefetch_related: list[Prefetch]
@@ -28,12 +29,14 @@ class QueryOptimizerStore:
     def __init__(self, model: type[Model]) -> None:
         self.model = model
         self.only_fields: list[str] = []
+        self.related_fields: list[str] = []
         self.select_stores: dict[str, "QueryOptimizerStore"] = {}
         self.prefetch_stores: dict[str, tuple["QueryOptimizerStore", QuerySet[Model]]] = {}
 
     def compile(self, *, in_prefetch: bool = False) -> CompilationResults:  # noqa: A003
         results = CompilationResults(
             only_fields=self.only_fields.copy(),
+            related_fields=self.related_fields.copy(),
             select_related=[],
             prefetch_related=[],
         )
@@ -82,14 +85,12 @@ def optimize_queryset(
             queryset = queryset.prefetch_related(*results.prefetch_related)
         if results.select_related:
             queryset = queryset.select_related(*results.select_related)
-        if results.only_fields and not optimizer_settings.DISABLE_ONLY_FIELDS_OPTIMIZATION:
-            queryset = queryset.only(*results.only_fields)
+        if not optimizer_settings.DISABLE_ONLY_FIELDS_OPTIMIZATION and (results.only_fields or results.related_fields):
+            queryset = queryset.only(*results.only_fields, *results.related_fields)
         if pk is not None:
             queryset = queryset.filter(pk=pk)
 
-        # Mark queryset as optimized so that later optimizers know to skip optimization
-        queryset._hints[optimizer_settings.OPTIMIZER_MARK] = True  # type: ignore[attr-defined]
-
+        mark_optimized(queryset)
         return queryset
 
     @property
@@ -103,6 +104,7 @@ def complexity(self) -> int:
 
     def __add__(self, other: "QueryOptimizerStore") -> "QueryOptimizerStore":
         self.only_fields += other.only_fields
+        self.related_fields += other.related_fields
         self.select_stores.update(other.select_stores)
         self.prefetch_stores.update(other.prefetch_stores)
         return self
diff --git a/query_optimizer/types.py b/query_optimizer/types.py
index 5031f91..b249b37 100644
--- a/query_optimizer/types.py
+++ b/query_optimizer/types.py
@@ -27,12 +27,12 @@ def max_complexity(cls) -> int:
     @classmethod
     def get_queryset(cls, queryset: QuerySet[TModel], info: GQLInfo) -> QuerySet[TModel]:
         if can_optimize(info):
-            queryset = optimize(queryset, info, cls.max_complexity())
+            queryset = optimize(queryset, info, max_complexity=cls.max_complexity())
         return queryset
 
     @classmethod
     def get_node(cls, info: GQLInfo, id: PK) -> Optional[TModel]:  # noqa: A002
         queryset: QuerySet[TModel] = cls._meta.model.objects.filter(pk=id)
         if can_optimize(info):
-            queryset = optimize(queryset, info, cls.max_complexity(), pk=id)
+            queryset = optimize(queryset, info, max_complexity=cls.max_complexity(), pk=id)
         return queryset.first()
diff --git a/query_optimizer/utils.py b/query_optimizer/utils.py
index 8d322d6..5834a45 100644
--- a/query_optimizer/utils.py
+++ b/query_optimizer/utils.py
@@ -1,10 +1,11 @@
-from django.db.models import ForeignKey
+from django.db.models import ForeignKey, QuerySet
 from graphene import Connection
 from graphene.types.definitions import GrapheneObjectType
 from graphene_django import DjangoObjectType
 from graphql import GraphQLOutputType, SelectionNode
 from graphql.execution.execute import get_field_def
 
+from .settings import optimizer_settings
 from .typing import Collection, GQLInfo, ModelField, ToManyField, ToOneField, TypeGuard, TypeVar
 
 __all__ = [
@@ -12,8 +13,11 @@
     "get_selections",
     "get_underlying_type",
     "is_foreign_key_id",
+    "is_optimized",
     "is_to_many",
     "is_to_one",
+    "mark_optimized",
+    "mark_unoptimized",
     "unique",
 ]
 
@@ -61,3 +65,18 @@ def can_optimize(info: GQLInfo) -> bool:
     return isinstance(return_type, GrapheneObjectType) and (
         issubclass(return_type.graphene_type, (DjangoObjectType, Connection))
     )
+
+
+def mark_optimized(queryset: QuerySet) -> None:
+    """Mark queryset as optimized so that later optimizers know to skip optimization"""
+    queryset._hints[optimizer_settings.OPTIMIZER_MARK] = True  # type: ignore[attr-defined]
+
+
+def mark_unoptimized(queryset: QuerySet) -> None:
+    """Mark queryset as unoptimized so that later optimizers will run optimization"""
+    queryset._hints.pop(optimizer_settings.OPTIMIZER_MARK, None)  # type: ignore[attr-defined]
+
+
+def is_optimized(queryset: QuerySet) -> bool:
+    """Has the queryset been optimized?"""
+    return queryset._hints.get(optimizer_settings.OPTIMIZER_MARK, False)  # type: ignore[attr-defined]
diff --git a/tests/example/types.py b/tests/example/types.py
index ce01108..f84a1bb 100644
--- a/tests/example/types.py
+++ b/tests/example/types.py
@@ -4,8 +4,9 @@
 from django_filters import CharFilter, FilterSet, OrderingFilter
 from graphene import relay
 
-from query_optimizer import DjangoObjectType, required_fields
+from query_optimizer import DjangoObjectType, optimize, required_fields
 from query_optimizer.typing import GQLInfo
+from query_optimizer.utils import can_optimize
 from tests.example.models import (
     Apartment,
     ApartmentProxy,
@@ -103,6 +104,17 @@ class SaleType(DjangoObjectType):
     class Meta:
         model = Sale
 
+    @classmethod
+    def filter_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet:
+        if can_optimize(info):
+            queryset = optimize(
+                queryset.filter(purchase_price__gte=1),
+                info,
+                max_complexity=cls.max_complexity(),
+                repopulate=True,
+            )
+        return queryset
+
 
 class OwnerType(DjangoObjectType):
     class Meta:
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 49950bf..053f789 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -6,7 +6,7 @@
 from graphql_relay import to_global_id
 
 from tests.example.models import Apartment, Building, HousingCompany
-from tests.example.types import ApartmentNode
+from tests.example.types import ApartmentNode, SaleType
 from tests.example.utils import count_queries
 
 pytestmark = pytest.mark.django_db
@@ -126,9 +126,7 @@
             }
           }
         }
-    """ % (
-        global_id,
-    )
+    """ % (global_id,)
 
     with count_queries() as results:
         response = client_query(query)
@@ -168,9 +166,7 @@
             }
           }
         }
-    """ % (
-        global_id,
-    )
+    """ % (global_id,)
 
     with count_queries() as results:
         response = client_query(query)
@@ -285,9 +281,7 @@
             }
           }
         }
-    """ % (
-        street_address,
-    )
+    """ % (street_address,)
 
     with count_queries() as results:
         response = client_query(query)
@@ -323,9 +317,7 @@
             }
           }
         }
-    """ % (
-        building_name,
-    )
+    """ % (building_name,)
 
     with count_queries() as results:
         response = client_query(query)
@@ -357,9 +349,7 @@
             }
           }
         }
-    """ % (
-        "foo",
-    )
+    """ % ("foo",)
 
     with count_queries() as results:
         response = client_query(query)
@@ -663,9 +653,7 @@
             }
           }
         }
-    """ % (
-        postal_code,
-    )
+    """ % (postal_code,)
 
     with count_queries() as results:
         response = client_query(query)
@@ -699,9 +687,7 @@
             }
           }
         }
-    """ % (
-        developer_name,
-    )
+    """ % (developer_name,)
 
     with count_queries() as results:
         response = client_query(query)
@@ -799,3 +785,43 @@
     queries = len(results.queries)
 
     assert queries == 0, results.message
+
+
+def test_optimizer_many_to_one_relations__additional_filtering(client_query):
+    query = """
+        query {
+          allApartments {
+            streetAddress
+            stair
+            apartmentNumber
+            sales {
+              purchaseDate
+              ownerships {
+                percentage
+                owner {
+                  name
+                }
+              }
+            }
+          }
+        }
+    """
+
+    original_get_queryset = SaleType.get_queryset
+    try:
+        SaleType.get_queryset = SaleType.filter_queryset
+        with count_queries() as results:
+            response = client_query(query)
+    finally:
+        SaleType.get_queryset = original_get_queryset
+
+    content = json.loads(response.content)
+    assert "errors" not in content, content["errors"]
+    assert "data" in content, content
+    assert "allApartments" in content["data"], content["data"]
+    apartments = content["data"]["allApartments"]
+    assert len(apartments) != 0, apartments
+
+    queries = len(results.queries)
+    # Normal 3 queries, and an additional 40+ to filter the results in SaleType.get_queryset
+    assert queries > 40, results.message
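
Usage sketch: the keyword-only optimize() signature and the repopulate flag introduced above are meant for the case described in the new docstring, where additional filters are applied to a queryset after optimization. The sketch below mirrors the SaleType.filter_queryset helper added to tests/example/types.py; the class name FilteredSaleType is hypothetical, and the Sale model and the purchase_price__gte=1 filter are borrowed from the test app, so adapt them to your own types.

from django.db.models import QuerySet

from query_optimizer import DjangoObjectType, optimize
from query_optimizer.typing import GQLInfo
from query_optimizer.utils import can_optimize
from tests.example.models import Sale


class FilteredSaleType(DjangoObjectType):  # hypothetical example type
    class Meta:
        model = Sale

    @classmethod
    def get_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet:
        if can_optimize(info):
            # Apply the extra filter first, then let optimize() rebuild
            # QuerySet._result_cache from the optimizer cache for the rows that
            # remain. repopulate=True bypasses the "already optimized" early return.
            queryset = optimize(
                queryset.filter(purchase_price__gte=1),
                info,
                max_complexity=cls.max_complexity(),
                repopulate=True,
            )
        return queryset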
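
A second sketch covers the new cache-mark helpers in query_optimizer/utils.py: mark_optimized sets the flag in QuerySet._hints that optimize() checks via is_optimized, and mark_unoptimized clears it so a later optimize() call runs again. The Sale model is again taken from the test app purely for illustration; any Django model works.

from query_optimizer.utils import is_optimized, mark_optimized, mark_unoptimized
from tests.example.models import Sale

queryset = Sale.objects.all()
assert not is_optimized(queryset)  # no optimizer mark in QuerySet._hints yet

mark_optimized(queryset)           # optimize() now returns this queryset unchanged,
assert is_optimized(queryset)      # unless it is called with repopulate=True

mark_unoptimized(queryset)         # clear the mark so optimization can run again
assert not is_optimized(queryset)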