1
+ from typing import TYPE_CHECKING
2
+
1
3
from django .core .exceptions import FieldDoesNotExist
2
4
from django .db .models import ForeignKey , ManyToOneRel , Model , QuerySet
3
5
from graphene .relay .connection import ConnectionOptions
4
6
from graphene .types .definitions import GrapheneObjectType , GrapheneUnionType
5
7
from graphene .utils .str_converters import to_snake_case
8
+ from graphene_django .registry import get_global_registry
6
9
from graphene_django .types import DjangoObjectTypeOptions
7
10
from graphql import (
8
11
FieldNode ,
39
42
is_to_one ,
40
43
)
41
44
45
+ if TYPE_CHECKING :
46
+ from graphene import ObjectType
47
+
48
+
42
49
TModel = TypeVar ("TModel" , bound = Model )
43
50
TCallable = TypeVar ("TCallable" , bound = Callable )
44
51
@@ -56,7 +63,6 @@ def optimize(
56
63
* ,
57
64
pk : PK = None ,
58
65
max_complexity : Optional [int ] = None ,
59
- repopulate : bool = False ,
60
66
) -> QuerySet [TModel ]:
61
67
"""
62
68
Optimize the given queryset according to the field selections
@@ -68,13 +74,10 @@ def optimize(
68
74
the query cache for that primary key before making query.
69
75
:param max_complexity: How many 'select_related' and 'prefetch_related' table joins are allowed.
70
76
Used to protect from malicious queries.
71
- :param repopulate: If True, repopulates the QuerySet._result_cache from the optimizer cache.
72
- This should be used when additional filters are applied to the queryset
73
- after optimization.
74
77
:return: The optimized queryset.
75
78
"""
76
79
# Check if prior optimization has been done already
77
- if not repopulate and is_optimized (queryset ):
80
+ if is_optimized (queryset ):
78
81
return queryset
79
82
80
83
field_type = get_field_type (info )
@@ -97,20 +100,6 @@ def optimize(
97
100
queryset ._result_cache = [cached_item ]
98
101
return queryset
99
102
100
- if repopulate :
101
- queryset ._result_cache : list [TModel ] = []
102
- for pk in queryset .values_list ("pk" , flat = True ):
103
- cached_item = get_from_query_cache (info .operation , info .schema , queryset .model , pk , store )
104
- if cached_item is None :
105
- msg = (
106
- f"Could not find '{ queryset .model .__class__ .__name__ } ' object with primary key "
107
- f"'{ pk } ' from the optimizer cache. Check that the queryset results are narrowed "
108
- f"and not expanded when repopulating."
109
- )
110
- raise ValueError (msg )
111
- queryset ._result_cache .append (cached_item )
112
- return queryset
113
-
114
103
queryset = store .optimize_queryset (queryset , pk = pk )
115
104
if optimizer .cache_results :
116
105
store_in_query_cache (info .operation , queryset , info .schema , store )
@@ -250,7 +239,8 @@ def find_field_from_model(
250
239
251
240
nested_store = QueryOptimizerStore (model = related_model )
252
241
if is_to_many (model_field ):
253
- store .prefetch_stores [model_field .name ] = (nested_store , related_model .objects .all ())
242
+ queryset = self .get_filtered_queryset (related_model )
243
+ store .prefetch_stores [model_field .name ] = (nested_store , queryset )
254
244
elif is_to_one (model_field ):
255
245
store .select_stores [model_field .name ] = nested_store
256
246
else : # pragma: no cover
@@ -342,9 +332,20 @@ def handle_to_many(
342
332
if isinstance (model_field , ManyToOneRel ):
343
333
nested_store .related_fields .append (model_field .field .name )
344
334
345
- related_queryset : QuerySet [Model ] = model_field .related_model .objects .all ()
335
+ related_model : type [Model ] = model_field .related_model # type: ignore[assignment]
336
+ if related_model == "self" : # pragma: no cover
337
+ related_model = model_field .model
338
+
339
+ related_queryset = self .get_filtered_queryset (related_model )
346
340
store .prefetch_stores [model_field_name ] = nested_store , related_queryset
347
341
342
+ def get_filtered_queryset (self , model : type [TModel ]) -> QuerySet [TModel ]:
343
+ qs : QuerySet = model .objects .all ()
344
+ object_type : ObjectType | None = get_global_registry ().get_type_for_model (model )
345
+ if hasattr (object_type , "filter_queryset" ) and callable (object_type .filter_queryset ):
346
+ qs = object_type .filter_queryset (qs , self .info )
347
+ return qs
348
+
348
349
def optimize_fragment_spread (
349
350
self ,
350
351
field_type : GrapheneObjectType ,
0 commit comments