Skip to content

Commit 8cb9a18

Browse files
authored
DSL using immutable graphql-core AST classes (#589)
1 parent 2997cd9 commit 8cb9a18

File tree

4 files changed

+93
-29
lines changed

4 files changed

+93
-29
lines changed

gql/dsl.py

Lines changed: 80 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from graphql import (
2626
ArgumentNode,
2727
BooleanValueNode,
28+
ConstDirectiveNode,
2829
DirectiveLocation,
2930
DirectiveNode,
3031
DocumentNode,
@@ -433,7 +434,10 @@ def args(self, **kwargs: Any) -> Self:
433434
arguments=tuple(
434435
ArgumentNode(
435436
name=NameNode(value=key),
436-
value=ast_from_value(value, self.directive_def.args[key].type),
437+
value=cast(
438+
ValueNode,
439+
ast_from_value(value, self.directive_def.args[key].type),
440+
),
437441
)
438442
for key, value in kwargs.items()
439443
),
@@ -596,7 +600,13 @@ def alias(self, alias: str) -> Self:
596600
:return: itself
597601
"""
598602

599-
self.ast_field.alias = NameNode(value=alias)
603+
self.ast_field = FieldNode(
604+
name=self.ast_field.name,
605+
alias=NameNode(value=alias),
606+
arguments=self.ast_field.arguments,
607+
directives=self.ast_field.directives,
608+
selection_set=self.ast_field.selection_set,
609+
)
600610
return self
601611

602612

@@ -667,7 +677,9 @@ def select(
667677
] = tuple(field.ast_field for field in added_fields)
668678

669679
# Update the current selection list with new selections
670-
self.selection_set.selections = self.selection_set.selections + added_selections
680+
self.selection_set = SelectionSetNode(
681+
selections=self.selection_set.selections + added_selections
682+
)
671683

672684
log.debug(f"Added fields: {added_fields} in {self!r}")
673685

@@ -799,7 +811,7 @@ def executable_ast(self) -> OperationDefinitionNode:
799811
operation=OperationType(self.operation_type),
800812
selection_set=self.selection_set,
801813
variable_definitions=self.variable_definitions.get_ast_definitions(),
802-
**({"name": NameNode(value=self.name)} if self.name else {}),
814+
name=NameNode(value=self.name) if self.name else None,
803815
directives=self.directives_ast,
804816
)
805817

@@ -857,7 +869,12 @@ def to_ast_type(self, type_: GraphQLInputType) -> TypeNode:
857869
return ListTypeNode(type=self.to_ast_type(type_.of_type))
858870

859871
elif isinstance(type_, GraphQLNonNull):
860-
return NonNullTypeNode(type=self.to_ast_type(type_.of_type))
872+
return NonNullTypeNode(
873+
type=cast(
874+
Union[NamedTypeNode, ListTypeNode],
875+
self.to_ast_type(type_.of_type),
876+
)
877+
)
861878

862879
assert isinstance(
863880
type_, (GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType)
@@ -924,14 +941,14 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]:
924941
"""
925942
return tuple(
926943
VariableDefinitionNode(
927-
type=var.ast_variable_type,
944+
type=cast(TypeNode, var.ast_variable_type),
928945
variable=var.ast_variable_name,
929946
default_value=(
930947
None
931948
if var.default_value is None
932949
else ast_from_value(var.default_value, var.type)
933950
),
934-
directives=var.directives_ast,
951+
directives=cast(Tuple[ConstDirectiveNode, ...], var.directives_ast),
935952
)
936953
for var in self.variables.values()
937954
if var.type is not None # only variables used
@@ -1141,13 +1158,23 @@ def args(self, **kwargs: Any) -> Self:
11411158

11421159
assert self.ast_field.arguments is not None
11431160

1144-
self.ast_field.arguments = self.ast_field.arguments + tuple(
1161+
new_arguments = self.ast_field.arguments + tuple(
11451162
ArgumentNode(
11461163
name=NameNode(value=name),
1147-
value=ast_from_value(value, self._get_argument(name).type),
1164+
value=cast(
1165+
ValueNode,
1166+
ast_from_value(value, self._get_argument(name).type),
1167+
),
11481168
)
11491169
for name, value in kwargs.items()
11501170
)
1171+
self.ast_field = FieldNode(
1172+
name=self.ast_field.name,
1173+
alias=self.ast_field.alias,
1174+
arguments=new_arguments,
1175+
directives=self.ast_field.directives,
1176+
selection_set=self.ast_field.selection_set,
1177+
)
11511178

11521179
log.debug(f"Added arguments {kwargs} in field {self!r})")
11531180

@@ -1175,14 +1202,26 @@ def select(
11751202
"""
11761203

11771204
super().select(*fields, **fields_with_alias)
1178-
self.ast_field.selection_set = self.selection_set
1205+
self.ast_field = FieldNode(
1206+
name=self.ast_field.name,
1207+
alias=self.ast_field.alias,
1208+
arguments=self.ast_field.arguments,
1209+
directives=self.ast_field.directives,
1210+
selection_set=self.selection_set,
1211+
)
11791212

11801213
return self
11811214

11821215
def directives(self, *directives: DSLDirective) -> Self:
11831216
"""Add directives to this field."""
11841217
super().directives(*directives)
1185-
self.ast_field.directives = self.directives_ast
1218+
self.ast_field = FieldNode(
1219+
name=self.ast_field.name,
1220+
alias=self.ast_field.alias,
1221+
arguments=self.ast_field.arguments,
1222+
directives=self.directives_ast,
1223+
selection_set=self.ast_field.selection_set,
1224+
)
11861225

11871226
return self
11881227

@@ -1254,7 +1293,10 @@ def __init__(
12541293

12551294
log.debug(f"Creating {self!r}")
12561295

1257-
self.ast_field = InlineFragmentNode(directives=())
1296+
self.ast_field = InlineFragmentNode(
1297+
selection_set=SelectionSetNode(selections=()),
1298+
directives=(),
1299+
)
12581300

12591301
DSLSelector.__init__(self, *fields, **fields_with_alias)
12601302
DSLDirectable.__init__(self)
@@ -1266,16 +1308,22 @@ def select(
12661308
corrected typing hints
12671309
"""
12681310
super().select(*fields, **fields_with_alias)
1269-
self.ast_field.selection_set = self.selection_set
1311+
self.ast_field = InlineFragmentNode(
1312+
selection_set=self.selection_set,
1313+
type_condition=self.ast_field.type_condition,
1314+
directives=self.ast_field.directives,
1315+
)
12701316

12711317
return self
12721318

12731319
def on(self, type_condition: DSLType) -> Self:
12741320
"""Provides the GraphQL type of this inline fragment."""
12751321

12761322
self._type = type_condition._type
1277-
self.ast_field.type_condition = NamedTypeNode(
1278-
name=NameNode(value=self._type.name)
1323+
self.ast_field = InlineFragmentNode(
1324+
selection_set=self.ast_field.selection_set,
1325+
type_condition=NamedTypeNode(name=NameNode(value=self._type.name)),
1326+
directives=self.ast_field.directives,
12791327
)
12801328
return self
12811329

@@ -1285,7 +1333,11 @@ def directives(self, *directives: DSLDirective) -> Self:
12851333
Inline fragments support all directive types through auto-validation.
12861334
"""
12871335
super().directives(*directives)
1288-
self.ast_field.directives = self.directives_ast
1336+
self.ast_field = InlineFragmentNode(
1337+
selection_set=self.ast_field.selection_set,
1338+
type_condition=self.ast_field.type_condition,
1339+
directives=self.directives_ast,
1340+
)
12891341
return self
12901342

12911343
def __repr__(self) -> str:
@@ -1338,7 +1390,10 @@ def directives(self, *directives: DSLDirective) -> Self:
13381390
Fragment spreads support all directive types through auto-validation.
13391391
"""
13401392
super().directives(*directives)
1341-
self.ast_field.directives = self.directives_ast
1393+
self.ast_field = FragmentSpreadNode(
1394+
name=self.ast_field.name,
1395+
directives=self.directives_ast,
1396+
)
13421397
return self
13431398

13441399
def is_valid_directive(self, directive: DSLDirective) -> bool:
@@ -1382,7 +1437,10 @@ def name(self) -> str:
13821437
def name(self, value: str) -> None:
13831438
""":meta private:"""
13841439
if hasattr(self, "ast_field"):
1385-
self.ast_field.name.value = value
1440+
self.ast_field = FragmentSpreadNode(
1441+
name=NameNode(value=value),
1442+
directives=self.ast_field.directives,
1443+
)
13861444

13871445
def spread(self) -> DSLFragmentSpread:
13881446
"""Create a fragment spread that can have its own directives.
@@ -1435,6 +1493,8 @@ def executable_ast(self) -> FragmentDefinitionNode:
14351493

14361494
fragment_variable_definitions = self.variable_definitions.get_ast_definitions()
14371495

1496+
variable_definition_kwargs: Dict[str, Any]
1497+
14381498
if len(fragment_variable_definitions) == 0:
14391499
"""Fragment variable definitions are obsolete and only supported on
14401500
graphql-core if the Parser is initialized with:
@@ -1452,9 +1512,9 @@ def executable_ast(self) -> FragmentDefinitionNode:
14521512
return FragmentDefinitionNode(
14531513
type_condition=NamedTypeNode(name=NameNode(value=self._type.name)),
14541514
selection_set=self.selection_set,
1455-
**variable_definition_kwargs,
14561515
name=NameNode(value=self.name),
14571516
directives=self.directives_ast,
1517+
**variable_definition_kwargs,
14581518
)
14591519

14601520
def is_valid_directive(self, directive: DSLDirective) -> bool:
@@ -1516,7 +1576,7 @@ def dsl_gql(
15161576
)
15171577

15181578
document = DocumentNode(
1519-
definitions=[operation.executable_ast for operation in all_operations]
1579+
definitions=tuple(operation.executable_ast for operation in all_operations)
15201580
)
15211581

15221582
return GraphQLRequest(document)

gql/utilities/get_introspection_query_ast.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,13 @@ def get_introspection_query_ast(
135135

136136
if type_recursion_level >= 1:
137137
current_field = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
138-
fragment_TypeRef.select(current_field)
139138

140139
for _ in repeat(None, type_recursion_level - 1):
141-
new_oftype = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
142-
current_field.select(new_oftype)
143-
current_field = new_oftype
140+
parent_field = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
141+
parent_field.select(current_field)
142+
current_field = parent_field
143+
144+
fragment_TypeRef.select(current_field)
144145

145146
query = DSLQuery(schema)
146147

gql/utilities/node_tree.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def _node_tree_recursive(
1414

1515
results = []
1616

17-
if hasattr(obj, "__slots__"):
17+
if hasattr(obj, "__slots__") or isinstance(obj, Node):
1818

1919
results.append(" " * indent + f"{type(obj).__name__}")
2020

@@ -89,4 +89,7 @@ def node_tree(
8989
# We are ignoring block attributes by default (in StringValueNode)
9090
ignored_keys.append("block")
9191

92+
# Ignore new field added in graphql-core 3.3.0a12 to keep output compatible
93+
ignored_keys.append("nullability_assertion")
94+
9295
return _node_tree_recursive(obj, ignored_keys=ignored_keys)

gql/utilities/parse_result.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ def enter_operation_definition(
124124
if not hasattr(node.name, "value"):
125125
return REMOVE # pragma: no cover
126126

127-
node.name = cast(NameNode, node.name)
127+
name = cast(NameNode, node.name)
128128

129-
if node.name.value != self.operation_name:
130-
log.debug(f"SKIPPING operation {node.name.value}")
129+
if name.value != self.operation_name:
130+
log.debug(f"SKIPPING operation {name.value}")
131131
return REMOVE
132132

133133
return IDLE
@@ -238,7 +238,7 @@ def enter_field(
238238
assert isinstance(selection_set_node, SelectionSetNode)
239239

240240
# Keep only the current node in a new selection set node
241-
new_node = SelectionSetNode(selections=[node])
241+
new_node = SelectionSetNode(selections=(node,))
242242

243243
for item in result_value:
244244

0 commit comments

Comments
 (0)