Skip to content

Commit beeece5

Browse files
committed
Add hook for doing additional filtering on object type
1 parent e8706c1 commit beeece5

File tree

5 files changed

+42
-79
lines changed

5 files changed

+42
-79
lines changed

query_optimizer/optimizer.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import TYPE_CHECKING
2+
13
from django.core.exceptions import FieldDoesNotExist
24
from django.db.models import ForeignKey, ManyToOneRel, Model, QuerySet
35
from graphene.relay.connection import ConnectionOptions
46
from graphene.types.definitions import GrapheneObjectType, GrapheneUnionType
57
from graphene.utils.str_converters import to_snake_case
8+
from graphene_django.registry import get_global_registry
69
from graphene_django.types import DjangoObjectTypeOptions
710
from graphql import (
811
FieldNode,
@@ -39,6 +42,10 @@
3942
is_to_one,
4043
)
4144

45+
if TYPE_CHECKING:
46+
from graphene import ObjectType
47+
48+
4249
TModel = TypeVar("TModel", bound=Model)
4350
TCallable = TypeVar("TCallable", bound=Callable)
4451

@@ -56,7 +63,6 @@ def optimize(
5663
*,
5764
pk: PK = None,
5865
max_complexity: Optional[int] = None,
59-
repopulate: bool = False,
6066
) -> QuerySet[TModel]:
6167
"""
6268
Optimize the given queryset according to the field selections
@@ -68,13 +74,10 @@ def optimize(
6874
the query cache for that primary key before making query.
6975
:param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed.
7076
Used to protect from malicious queries.
71-
:param repopulate: If True, repopulates the QuerySet._result_cache from the optimizer cache.
72-
This should be used when additional filters are applied to the queryset
73-
after optimization.
7477
:return: The optimized queryset.
7578
"""
7679
# Check if prior optimization has been done already
77-
if not repopulate and is_optimized(queryset):
80+
if is_optimized(queryset):
7881
return queryset
7982

8083
field_type = get_field_type(info)
@@ -97,20 +100,6 @@ def optimize(
97100
queryset._result_cache = [cached_item]
98101
return queryset
99102

100-
if repopulate:
101-
queryset._result_cache: list[TModel] = []
102-
for pk in queryset.values_list("pk", flat=True):
103-
cached_item = get_from_query_cache(info.operation, info.schema, queryset.model, pk, store)
104-
if cached_item is None:
105-
msg = (
106-
f"Could not find '{queryset.model.__class__.__name__}' object with primary key "
107-
f"'{pk}' from the optimizer cache. Check that the queryset results are narrowed "
108-
f"and not expanded when repopulating."
109-
)
110-
raise ValueError(msg)
111-
queryset._result_cache.append(cached_item)
112-
return queryset
113-
114103
queryset = store.optimize_queryset(queryset, pk=pk)
115104
if optimizer.cache_results:
116105
store_in_query_cache(info.operation, queryset, info.schema, store)
@@ -250,7 +239,8 @@ def find_field_from_model(
250239

251240
nested_store = QueryOptimizerStore(model=related_model)
252241
if is_to_many(model_field):
253-
store.prefetch_stores[model_field.name] = (nested_store, related_model.objects.all())
242+
queryset = self.get_filtered_queryset(related_model)
243+
store.prefetch_stores[model_field.name] = (nested_store, queryset)
254244
elif is_to_one(model_field):
255245
store.select_stores[model_field.name] = nested_store
256246
else: # pragma: no cover
@@ -342,9 +332,20 @@ def handle_to_many(
342332
if isinstance(model_field, ManyToOneRel):
343333
nested_store.related_fields.append(model_field.field.name)
344334

345-
related_queryset: QuerySet[Model] = model_field.related_model.objects.all()
335+
related_model: type[Model] = model_field.related_model # type: ignore[assignment]
336+
if related_model == "self": # pragma: no cover
337+
related_model = model_field.model
338+
339+
related_queryset = self.get_filtered_queryset(related_model)
346340
store.prefetch_stores[model_field_name] = nested_store, related_queryset
347341

342+
def get_filtered_queryset(self, model: type[TModel]) -> QuerySet[TModel]:
343+
qs: QuerySet = model.objects.all()
344+
object_type: ObjectType | None = get_global_registry().get_type_for_model(model)
345+
if hasattr(object_type, "filter_queryset") and callable(object_type.filter_queryset):
346+
qs = object_type.filter_queryset(qs, self.info)
347+
return qs
348+
348349
def optimize_fragment_spread(
349350
self,
350351
field_type: GrapheneObjectType,

query_optimizer/typing.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Callable, Collection, Hashable, Iterable, NamedTuple, Optional, TypeVar, Union
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Callable, Collection, Hashable, Iterable, NamedTuple, Optional, TypeVar, Union
24

35
from graphene.relay.connection import ConnectionOptions
46
from graphene_django.types import DjangoObjectTypeOptions
@@ -10,7 +12,6 @@
1012
from typing_extensions import TypeAlias, TypeGuard
1113

1214

13-
import graphql
1415
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
1516
from django.core.handlers.wsgi import WSGIRequest
1617
from django.db.models import (
@@ -24,6 +25,10 @@
2425
Model,
2526
OneToOneField,
2627
)
28+
from graphql import GraphQLResolveInfo
29+
30+
if TYPE_CHECKING:
31+
from django.contrib.auth.models import AnonymousUser, User
2732

2833
__all__ = [
2934
"Any",
@@ -48,10 +53,6 @@
4853
]
4954

5055

51-
class GQLInfo(graphql.GraphQLResolveInfo):
52-
context: WSGIRequest
53-
54-
5556
TModel = TypeVar("TModel", bound=Model)
5657
TableName: TypeAlias = str
5758
StoreStr: TypeAlias = str
@@ -61,3 +62,12 @@ class GQLInfo(graphql.GraphQLResolveInfo):
6162
ToManyField: TypeAlias = Union[GenericRelation, ManyToManyField, ManyToOneRel, ManyToManyRel]
6263
ToOneField: TypeAlias = Union[GenericRelation, ForeignObject, ForeignKey, OneToOneField]
6364
TypeOptions: TypeAlias = Union[DjangoObjectTypeOptions, ConnectionOptions]
65+
AnyUser: TypeAlias = Union["User", "AnonymousUser"]
66+
67+
68+
class UserHintedWSGIRequest(WSGIRequest):
69+
user: AnyUser
70+
71+
72+
class GQLInfo(GraphQLResolveInfo):
73+
context: UserHintedWSGIRequest

query_optimizer/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def mark_optimized(queryset: QuerySet) -> None:
7272
queryset._hints[optimizer_settings.OPTIMIZER_MARK] = True # type: ignore[attr-defined]
7373

7474

75-
def mark_unoptimized(queryset: QuerySet) -> None:
75+
def mark_unoptimized(queryset: QuerySet) -> None: # pragma: no cover
7676
"""Mark queryset as unoptimized so that later optimizers will run optimization"""
7777
queryset._hints.pop(optimizer_settings.OPTIMIZER_MARK, None) # type: ignore[attr-defined]
7878

tests/example/types.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from django_filters import CharFilter, FilterSet, OrderingFilter
55
from graphene import relay
66

7-
from query_optimizer import DjangoObjectType, optimize, required_fields
7+
from query_optimizer import DjangoObjectType, required_fields
88
from query_optimizer.typing import GQLInfo
9-
from query_optimizer.utils import can_optimize
109
from tests.example.models import (
1110
Apartment,
1211
ApartmentProxy,
@@ -106,14 +105,7 @@ class Meta:
106105

107106
@classmethod
108107
def filter_queryset(cls, queryset: QuerySet, info: GQLInfo) -> QuerySet:
109-
if can_optimize(info):
110-
queryset = optimize(
111-
queryset.filter(purchase_price__gte=1),
112-
info,
113-
max_complexity=cls.max_complexity(),
114-
repopulate=True,
115-
)
116-
return queryset
108+
return queryset.filter(purchase_price__gte=1)
117109

118110

119111
class OwnerType(DjangoObjectType):

tests/test_optimizer.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from graphql_relay import to_global_id
77

88
from tests.example.models import Apartment, Building, HousingCompany
9-
from tests.example.types import ApartmentNode, SaleType
9+
from tests.example.types import ApartmentNode
1010
from tests.example.utils import count_queries
1111

1212
pytestmark = pytest.mark.django_db
@@ -785,43 +785,3 @@ def test_optimizer_max_complexity_reached(client_query):
785785

786786
queries = len(results.queries)
787787
assert queries == 0, results.message
788-
789-
790-
def test_optimizer_many_to_one_relations__additional_filtering(client_query):
791-
query = """
792-
query {
793-
allApartments {
794-
streetAddress
795-
stair
796-
apartmentNumber
797-
sales {
798-
purchaseDate
799-
ownerships {
800-
percentage
801-
owner {
802-
name
803-
}
804-
}
805-
}
806-
}
807-
}
808-
"""
809-
810-
original_get_queryset = SaleType.get_queryset
811-
try:
812-
SaleType.get_queryset = SaleType.filter_queryset
813-
with count_queries() as results:
814-
response = client_query(query)
815-
finally:
816-
SaleType.get_queryset = original_get_queryset
817-
818-
content = json.loads(response.content)
819-
assert "errors" not in content, content["errors"]
820-
assert "data" in content, content
821-
assert "allApartments" in content["data"], content["data"]
822-
apartments = content["data"]["allApartments"]
823-
assert len(apartments) != 0, apartments
824-
825-
queries = len(results.queries)
826-
# Normal 3 queries, and an additional 40+ to filter the results in Sales.get_queryset
827-
assert queries > 40, results.message

0 commit comments

Comments
 (0)