Skip to content

Commit

Permalink
Initial fix to filtering object type queryset
Browse files Browse the repository at this point in the history
  • Loading branch information
MrThearMan committed Dec 18, 2023
1 parent 454d2fb commit e8706c1
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 40 deletions.
36 changes: 28 additions & 8 deletions query_optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_selections,
get_underlying_type,
is_foreign_key_id,
is_optimized,
is_to_many,
is_to_one,
)
Expand All @@ -52,23 +53,28 @@
def optimize(
queryset: QuerySet[TModel],
info: GQLInfo,
max_complexity: Optional[int] = None,
*,
pk: PK = None,
max_complexity: Optional[int] = None,
repopulate: bool = False,
) -> QuerySet[TModel]:
"""
Optimize the given queryset according to the field selections
received in the GraphQLResolveInfo.
:param queryset: Base queryset to optimize from.
:param info: The GraphQLResolveInfo object used in the optimization process.
:param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed.
Used to protect from malicious queries.
:param pk: Primary key for an item in the queryset model. If set, optimizer will check
the query cache for that primary key before making query.
:param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed.
Used to protect from malicious queries.
:param repopulate: If True, repopulates the QuerySet._result_cache from the optimizer cache.
This should be used when additional filters are applied to the queryset
after optimization.
:return: The optimized queryset.
"""
# Check if prior optimization has been done already
if queryset._hints.get(optimizer_settings.OPTIMIZER_MARK, False): # type: ignore[attr-defined]
if not repopulate and is_optimized(queryset):
return queryset

field_type = get_field_type(info)
Expand All @@ -91,6 +97,20 @@ def optimize(
queryset._result_cache = [cached_item]
return queryset

if repopulate:
queryset._result_cache: list[TModel] = []
for pk in queryset.values_list("pk", flat=True):
cached_item = get_from_query_cache(info.operation, info.schema, queryset.model, pk, store)
if cached_item is None:
msg = (
f"Could not find '{queryset.model.__class__.__name__}' object with primary key "
f"'{pk}' from the optimizer cache. Check that the queryset results are narrowed "
f"and not expanded when repopulating."
)
raise ValueError(msg)
queryset._result_cache.append(cached_item)
return queryset

queryset = store.optimize_queryset(queryset, pk=pk)
if optimizer.cache_results:
store_in_query_cache(info.operation, queryset, info.schema, store)
Expand Down Expand Up @@ -296,8 +316,8 @@ def handle_to_one(
model_field.model,
)

if isinstance(model_field, ForeignKey): # Add connecting entity
store.only_fields.append(model_field_name)
if isinstance(model_field, ForeignKey):
store.related_fields.append(model_field_name)

store.select_stores[model_field_name] = nested_store

Expand All @@ -319,8 +339,8 @@ def handle_to_many(
model_field.model,
)

if isinstance(model_field, ManyToOneRel): # Add connecting entity
nested_store.only_fields.append(model_field.field.name)
if isinstance(model_field, ManyToOneRel):
nested_store.related_fields.append(model_field.field.name)

related_queryset: QuerySet[Model] = model_field.related_model.objects.all()
store.prefetch_stores[model_field_name] = nested_store, related_queryset
Expand Down
14 changes: 8 additions & 6 deletions query_optimizer/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .settings import optimizer_settings
from .typing import PK, TypeVar
from .utils import unique
from .utils import mark_optimized, unique

TModel = TypeVar("TModel", bound=Model)

Expand All @@ -18,6 +18,7 @@
@dataclass
class CompilationResults:
    """Collected arguments for optimizing a queryset."""

    # Field names to pass to `QuerySet.only()`.
    only_fields: list[str]
    # Names of fields connecting related models (added for ForeignKey /
    # ManyToOneRel relations) so `only()` keeps the columns joins depend on.
    related_fields: list[str]
    # Lookups to pass to `QuerySet.select_related()`.
    select_related: list[str]
    # `Prefetch` objects to pass to `QuerySet.prefetch_related()`.
    prefetch_related: list[Prefetch]

Expand All @@ -28,12 +29,14 @@ class QueryOptimizerStore:
def __init__(self, model: type[Model]) -> None:
    """Initialize an empty optimizer store for the given model."""
    self.model = model
    # Field names to fetch with `QuerySet.only()`.
    self.only_fields: list[str] = []
    # Names of fields connecting related models (e.g. ForeignKey columns);
    # merged into `only()` alongside `only_fields` when compiling.
    self.related_fields: list[str] = []
    # Nested stores keyed by relation name, applied via `select_related`.
    self.select_stores: dict[str, "QueryOptimizerStore"] = {}
    # Nested (store, base queryset) pairs keyed by relation name,
    # applied via `prefetch_related`.
    self.prefetch_stores: dict[str, tuple["QueryOptimizerStore", QuerySet[Model]]] = {}

def compile(self, *, in_prefetch: bool = False) -> CompilationResults: # noqa: A003
results = CompilationResults(
only_fields=self.only_fields.copy(),
related_fields=self.related_fields.copy(),
select_related=[],
prefetch_related=[],
)
Expand Down Expand Up @@ -82,14 +85,12 @@ def optimize_queryset(
queryset = queryset.prefetch_related(*results.prefetch_related)
if results.select_related:
queryset = queryset.select_related(*results.select_related)
if results.only_fields and not optimizer_settings.DISABLE_ONLY_FIELDS_OPTIMIZATION:
queryset = queryset.only(*results.only_fields)
if not optimizer_settings.DISABLE_ONLY_FIELDS_OPTIMIZATION and (results.only_fields or results.related_fields):
queryset = queryset.only(*results.only_fields, *results.related_fields)
if pk is not None:
queryset = queryset.filter(pk=pk)

# Mark queryset as optimized so that later optimizers know to skip optimization
queryset._hints[optimizer_settings.OPTIMIZER_MARK] = True # type: ignore[attr-defined]

mark_optimized(queryset)
return queryset

@property
Expand All @@ -103,6 +104,7 @@ def complexity(self) -> int:

def __add__(self, other: "QueryOptimizerStore") -> "QueryOptimizerStore":
    """Merge the contents of `other` into this store in place and return this store."""
    self.only_fields.extend(other.only_fields)
    self.related_fields.extend(other.related_fields)
    self.select_stores.update(other.select_stores)
    self.prefetch_stores.update(other.prefetch_stores)
    return self
Expand Down
4 changes: 2 additions & 2 deletions query_optimizer/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def max_complexity(cls) -> int:
@classmethod
def get_queryset(cls, queryset: QuerySet[TModel], info: GQLInfo) -> QuerySet[TModel]:
    """Optimize the base queryset for the current resolve info, when optimization applies."""
    if not can_optimize(info):
        return queryset
    return optimize(queryset, info, max_complexity=cls.max_complexity())

@classmethod
def get_node(cls, info: GQLInfo, id: PK) -> Optional[TModel]:  # noqa: A002
    """Fetch a single node by primary key, optimizing the lookup when possible."""
    node_qs: QuerySet[TModel] = cls._meta.model.objects.filter(pk=id)
    if can_optimize(info):
        node_qs = optimize(node_qs, info, max_complexity=cls.max_complexity(), pk=id)
    return node_qs.first()
21 changes: 20 additions & 1 deletion query_optimizer/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from django.db.models import ForeignKey
from django.db.models import ForeignKey, QuerySet
from graphene import Connection
from graphene.types.definitions import GrapheneObjectType
from graphene_django import DjangoObjectType
from graphql import GraphQLOutputType, SelectionNode
from graphql.execution.execute import get_field_def

from .settings import optimizer_settings
from .typing import Collection, GQLInfo, ModelField, ToManyField, ToOneField, TypeGuard, TypeVar

__all__ = [
"get_field_type",
"get_selections",
"get_underlying_type",
"is_foreign_key_id",
"is_optimized",
"is_to_many",
"is_to_one",
"mark_optimized",
"mark_unoptimized",
"unique",
]

Expand Down Expand Up @@ -61,3 +65,18 @@ def can_optimize(info: GQLInfo) -> bool:
return isinstance(return_type, GrapheneObjectType) and (
issubclass(return_type.graphene_type, (DjangoObjectType, Connection))
)


def mark_optimized(queryset: QuerySet) -> None:
    """Flag the queryset as optimized so later optimizers know to skip it."""
    hints = queryset._hints  # type: ignore[attr-defined]
    hints[optimizer_settings.OPTIMIZER_MARK] = True


def mark_unoptimized(queryset: QuerySet) -> None:
    """Clear the optimization flag so later optimizers will run again."""
    hints = queryset._hints  # type: ignore[attr-defined]
    hints.pop(optimizer_settings.OPTIMIZER_MARK, None)


def is_optimized(queryset: QuerySet) -> bool:
    """Has the queryset been optimized?"""
    return queryset._hints.get(optimizer_settings.OPTIMIZER_MARK, False)  # type: ignore[attr-defined]
14 changes: 13 additions & 1 deletion tests/example/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from django_filters import CharFilter, FilterSet, OrderingFilter
from graphene import relay

from query_optimizer import DjangoObjectType, required_fields
from query_optimizer import DjangoObjectType, optimize, required_fields
from query_optimizer.typing import GQLInfo
from query_optimizer.utils import can_optimize
from tests.example.models import (
Apartment,
ApartmentProxy,
Expand Down Expand Up @@ -103,6 +104,17 @@ class SaleType(DjangoObjectType):
class Meta:
model = Sale

@classmethod
def filter_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet:
    """Apply an extra filter and re-optimize, repopulating results from the optimizer cache."""
    if not can_optimize(info):
        return queryset
    filtered = queryset.filter(purchase_price__gte=1)
    return optimize(
        filtered,
        info,
        max_complexity=cls.max_complexity(),
        repopulate=True,
    )


class OwnerType(DjangoObjectType):
class Meta:
Expand Down
70 changes: 48 additions & 22 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from graphql_relay import to_global_id

from tests.example.models import Apartment, Building, HousingCompany
from tests.example.types import ApartmentNode
from tests.example.types import ApartmentNode, SaleType
from tests.example.utils import count_queries

pytestmark = pytest.mark.django_db
Expand Down Expand Up @@ -126,9 +126,7 @@ def test_optimizer_relay_node(client_query):
}
}
}
""" % (
global_id,
)
""" % (global_id,)

with count_queries() as results:
response = client_query(query)
Expand Down Expand Up @@ -168,9 +166,7 @@ def test_optimizer_relay_node_deep(client_query):
}
}
}
""" % (
global_id,
)
""" % (global_id,)

with count_queries() as results:
response = client_query(query)
Expand Down Expand Up @@ -285,9 +281,7 @@ def test_optimizer_relay_connection_filtering(client_query):
}
}
}
""" % (
street_address,
)
""" % (street_address,)

with count_queries() as results:
response = client_query(query)
Expand Down Expand Up @@ -323,9 +317,7 @@ def test_optimizer_relay_connection_filtering_nested(client_query):
}
}
}
""" % (
building_name,
)
""" % (building_name,)

with count_queries() as results:
response = client_query(query)
Expand Down Expand Up @@ -357,9 +349,7 @@ def test_optimizer_relay_connection_filtering_empty(client_query):
}
}
}
""" % (
"foo",
)
""" % ("foo",)

with count_queries() as results:
response = client_query(query)
Expand Down Expand Up @@ -663,9 +653,7 @@ def test_optimizer_filter(client_query):
}
}
}
""" % (
postal_code,
)
""" % (postal_code,)

with count_queries() as results:
response = client_query(query)
Expand Down Expand Up @@ -699,9 +687,7 @@ def test_optimizer_filter_to_many_relations(client_query):
}
}
}
""" % (
developer_name,
)
""" % (developer_name,)

with count_queries() as results:
response = client_query(query)
Expand Down Expand Up @@ -799,3 +785,43 @@ def test_optimizer_max_complexity_reached(client_query):

queries = len(results.queries)
assert queries == 0, results.message


def test_optimizer_many_to_one_relations__additional_filtering(client_query):
    """Additional filtering in `SaleType.filter_queryset` repopulates results per-row."""
    query = """
        query {
          allApartments {
            streetAddress
            stair
            apartmentNumber
            sales {
              purchaseDate
              ownerships {
                percentage
                owner {
                  name
                }
              }
            }
          }
        }
    """

    # Temporarily swap in the filtering resolver so `optimize` runs with `repopulate=True`.
    saved_get_queryset = SaleType.get_queryset
    SaleType.get_queryset = SaleType.filter_queryset
    try:
        with count_queries() as results:
            response = client_query(query)
    finally:
        SaleType.get_queryset = saved_get_queryset

    content = json.loads(response.content)
    assert "errors" not in content, content["errors"]
    assert "data" in content, content
    assert "allApartments" in content["data"], content["data"]
    apartments = content["data"]["allApartments"]
    assert len(apartments) != 0, apartments

    query_count = len(results.queries)
    # 3 queries normally, plus 40+ more to filter the results in `SaleType.filter_queryset`.
    assert query_count > 40, results.message

0 comments on commit e8706c1

Please sign in to comment.