Skip to content

Commit

Permalink
Added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DamianCzajkowski committed Sep 9, 2024
1 parent a2031e8 commit c878df9
Show file tree
Hide file tree
Showing 24 changed files with 635 additions and 166 deletions.
12 changes: 11 additions & 1 deletion ariadne_graphql_modules/base_object_type/graphql_field.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from ariadne.types import Resolver, Subscriber
from graphql import FieldDefinitionNode, NamedTypeNode


@dataclass(frozen=True)
Expand All @@ -15,7 +16,16 @@ class GraphQLObjectFieldArg:
@dataclass(frozen=True)
class GraphQLObjectData:
fields: Dict[str, "GraphQLObjectField"]
interfaces: List[str]
interfaces: List[NamedTypeNode]


@dataclass
class GraphQLClassData:
type_aliases: Dict[str, str] = field(default_factory=dict)
fields_ast: Dict[str, FieldDefinitionNode] = field(default_factory=dict)
resolvers: Dict[str, "Resolver"] = field(default_factory=dict)
aliases: Dict[str, str] = field(default_factory=dict)
out_names: Dict[str, Dict[str, str]] = field(default_factory=dict)


@dataclass(frozen=True)
Expand Down
39 changes: 17 additions & 22 deletions ariadne_graphql_modules/base_object_type/graphql_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
Expand All @@ -23,6 +22,7 @@
from ..types import GraphQLClassType

from .graphql_field import (
GraphQLClassData,
GraphQLFieldData,
GraphQLObjectData,
GraphQLObjectField,
Expand Down Expand Up @@ -54,7 +54,7 @@
class GraphQLBaseObject(GraphQLType):
__kwargs__: Dict[str, Any]
__abstract__: bool = True
__schema__: Optional[str]
__schema__: Optional[str] = None
__description__: Optional[str]
__aliases__: Optional[Dict[str, str]]
__requires__: Optional[Iterable[Union[Type[GraphQLType], Type[Enum]]]]
Expand Down Expand Up @@ -188,33 +188,29 @@ def _create_fields_and_resolvers_with_schema(

@classmethod
def _process_graphql_fields(
cls, metadata: GraphQLMetadata, type_data, type_aliases
) -> Tuple[
List[FieldDefinitionNode],
Dict[str, Resolver],
Dict[str, str],
Dict[str, Dict[str, str]],
]:
fields_ast = []
resolvers = {}
aliases = {}
out_names = {}

cls,
metadata: GraphQLMetadata,
type_data,
type_aliases,
object_model_data: GraphQLClassData,
):
for attr_name, field in type_data.fields.items():
fields_ast.append(get_field_node_from_obj_field(cls, metadata, field))
object_model_data.fields_ast[attr_name] = get_field_node_from_obj_field(
cls, metadata, field
)

if attr_name in type_aliases and field.name:
aliases[field.name] = type_aliases[attr_name]
object_model_data.aliases[field.name] = type_aliases[attr_name]
elif field.name and attr_name != field.name and not field.resolver:
aliases[field.name] = attr_name
object_model_data.aliases[field.name] = attr_name

if field.resolver and field.name:
resolvers[field.name] = field.resolver
object_model_data.resolvers[field.name] = field.resolver

if field.args and field.name:
out_names[field.name] = get_field_args_out_names(field.args)

return fields_ast, resolvers, aliases, out_names
object_model_data.out_names[field.name] = get_field_args_out_names(
field.args
)

@classmethod
def __get_graphql_types__(
Expand All @@ -232,7 +228,6 @@ def __get_graphql_types_with_schema__(
) -> Iterable[Type["GraphQLType"]]:
types: List[Type["GraphQLType"]] = [cls]
types.extend(getattr(cls, "__requires__", []))
types.extend(getattr(cls, "__implements__", []))
return types

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions ariadne_graphql_modules/compatibility_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .v1.subscription_type import SubscriptionType
from .v1.union_type import UnionType
from .v1.object_type import ObjectType
from .v1.bases import BindableType
from .v1.bases import BaseType, BindableType

from .base import GraphQLModel, GraphQLType
from . import (
Expand All @@ -38,7 +38,7 @@


def wrap_legacy_types(
*bindable_types: Type[BindableType],
*bindable_types: Type[BaseType],
) -> List[Type["LegacyGraphQLType"]]:
all_types = get_all_types(bindable_types)

Expand Down
8 changes: 8 additions & 0 deletions ariadne_graphql_modules/deferredtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@

@dataclass(frozen=True)
class DeferredTypeData:
"""Data class representing deferred type information with a module path."""

path: str


def deferred(module_path: str) -> DeferredTypeData:
"""
Create a DeferredTypeData object from a given module path.
If the module path is relative (starts with '.'),
resolve it based on the caller's package context.
"""
if not module_path.startswith("."):
return DeferredTypeData(module_path)

Expand Down
7 changes: 7 additions & 0 deletions ariadne_graphql_modules/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@


def get_description_node(description: Optional[str]) -> Optional[StringValueNode]:
"""Convert a description string into a GraphQL StringValueNode.
If the description is provided, it will be dedented, stripped of surrounding
whitespace, and used to create a StringValueNode. If the description contains
newline characters, the `block` attribute of the StringValueNode
will be set to `True`.
"""
if not description:
return None

Expand Down
11 changes: 8 additions & 3 deletions ariadne_graphql_modules/executable_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


def make_executable_schema(
*types: SchemaType,
*types: Union[SchemaType, List[SchemaType]],
directives: Optional[Dict[str, Type[SchemaDirectiveVisitor]]] = None,
convert_names_case: Union[bool, SchemaNameConverter] = False,
merge_roots: bool = True,
Expand Down Expand Up @@ -96,7 +96,12 @@ def make_executable_schema(
return schema


def find_type_defs(types: Sequence[SchemaType]) -> List[str]:
def find_type_defs(
types: Union[
tuple[SchemaType | List[SchemaType], ...],
List[SchemaType],
]
) -> List[str]:
type_defs: List[str] = []

for type_def in types:
Expand All @@ -109,7 +114,7 @@ def find_type_defs(types: Sequence[SchemaType]) -> List[str]:


def flatten_types(
types: Sequence[SchemaType],
types: tuple[SchemaType | List[SchemaType], ...],
metadata: GraphQLMetadata,
) -> List[SchemaType]:
flat_schema_types_list: List[SchemaType] = flatten_schema_types(
Expand Down
36 changes: 18 additions & 18 deletions ariadne_graphql_modules/interface_type/graphql_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, Optional, Tuple, cast

from ariadne.types import Resolver
from graphql import (
Expand All @@ -8,6 +8,8 @@
NamedTypeNode,
)

from ariadne_graphql_modules.base_object_type.graphql_field import GraphQLClassData

from ..base_object_type import (
GraphQLFieldData,
GraphQLBaseObject,
Expand All @@ -29,6 +31,7 @@ class GraphQLInterface(GraphQLBaseObject):
__graphql_type__ = GraphQLClassType.INTERFACE
__abstract__ = True
__description__: Optional[str] = None
__graphql_name__: Optional[str] = None

def __init_subclass__(cls) -> None:
super().__init_subclass__()
Expand Down Expand Up @@ -78,19 +81,11 @@ def __get_graphql_model_without_schema__(
type_data = cls.get_graphql_object_data(metadata)
type_aliases = getattr(cls, "__aliases__", None) or {}

fields_ast: List[FieldDefinitionNode] = []
resolvers: Dict[str, Resolver] = {}
aliases: Dict[str, str] = {}
out_names: Dict[str, Dict[str, str]] = {}

fields_ast, resolvers, aliases, out_names = cls._process_graphql_fields(
metadata, type_data, type_aliases
object_model_data = GraphQLClassData()
cls._process_graphql_fields(
metadata, type_data, type_aliases, object_model_data
)

interfaces_ast: List[NamedTypeNode] = []
for interface_name in type_data.interfaces:
interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name)))

return GraphQLInterfaceModel(
name=name,
ast_type=InterfaceTypeDefinitionNode,
Expand All @@ -99,13 +94,13 @@ def __get_graphql_model_without_schema__(
description=get_description_node(
getattr(cls, "__description__", None),
),
fields=tuple(fields_ast),
interfaces=tuple(interfaces_ast),
fields=tuple(object_model_data.fields_ast.values()),
interfaces=tuple(type_data.interfaces),
),
resolve_type=cls.resolve_type,
resolvers=resolvers,
aliases=aliases,
out_names=out_names,
resolvers=object_model_data.resolvers,
aliases=object_model_data.aliases,
out_names=object_model_data.out_names,
)

@staticmethod
Expand Down Expand Up @@ -143,5 +138,10 @@ def create_graphql_object_data_without_schema(cls) -> GraphQLObjectData:

return GraphQLObjectData(
fields=cls._build_fields(fields_data=fields_data),
interfaces=[],
interfaces=[
NamedTypeNode(name=NameNode(value=interface.__name__))
for interface in inherited_objects
if getattr(interface, "__graphql_type__", None)
== GraphQLClassType.INTERFACE
],
)
47 changes: 13 additions & 34 deletions ariadne_graphql_modules/object_type/graphql_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import (
Dict,
List,
Optional,
Tuple,
cast,
Expand All @@ -14,14 +13,14 @@
ObjectTypeDefinitionNode,
)

from ariadne_graphql_modules.base_object_type.graphql_field import GraphQLClassData

from ..types import GraphQLClassType

from ..base_object_type import (
GraphQLFieldData,
GraphQLBaseObject,
GraphQLObjectData,
validate_object_type_with_schema,
validate_object_type_without_schema,
)
from .models import GraphQLObjectModel

Expand All @@ -36,20 +35,8 @@ class GraphQLObject(GraphQLBaseObject):
__graphql_type__ = GraphQLClassType.OBJECT
__abstract__ = True
__description__: Optional[str] = None

def __init_subclass__(cls) -> None:
super().__init_subclass__()

if cls.__dict__.get("__abstract__"):
return

cls.__abstract__ = False

if cls.__dict__.get("__schema__"):
valid_type = getattr(cls, "__valid_type__", ObjectTypeDefinitionNode)
cls.__kwargs__ = validate_object_type_with_schema(cls, valid_type)
else:
cls.__kwargs__ = validate_object_type_without_schema(cls)
__schema__: Optional[str] = None
__graphql_name__: Optional[str] = None

@classmethod
def __get_graphql_model_with_schema__(cls) -> "GraphQLModel":
Expand Down Expand Up @@ -84,19 +71,11 @@ def __get_graphql_model_without_schema__(
type_data = cls.get_graphql_object_data(metadata)
type_aliases = getattr(cls, "__aliases__", {})

fields_ast: List[FieldDefinitionNode] = []
resolvers: Dict[str, Resolver] = {}
aliases: Dict[str, str] = {}
out_names: Dict[str, Dict[str, str]] = {}

fields_ast, resolvers, aliases, out_names = cls._process_graphql_fields(
metadata, type_data, type_aliases
object_model_data = GraphQLClassData()
cls._process_graphql_fields(
metadata, type_data, type_aliases, object_model_data
)

interfaces_ast: List[NamedTypeNode] = []
for interface_name in type_data.interfaces:
interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name)))

return GraphQLObjectModel(
name=name,
ast_type=ObjectTypeDefinitionNode,
Expand All @@ -105,12 +84,12 @@ def __get_graphql_model_without_schema__(
description=get_description_node(
getattr(cls, "__description__", None),
),
fields=tuple(fields_ast),
interfaces=tuple(interfaces_ast),
fields=tuple(object_model_data.fields_ast.values()),
interfaces=tuple(type_data.interfaces),
),
resolvers=resolvers,
aliases=aliases,
out_names=out_names,
resolvers=object_model_data.resolvers,
aliases=object_model_data.aliases,
out_names=object_model_data.out_names,
)

@classmethod
Expand Down Expand Up @@ -140,7 +119,7 @@ def create_graphql_object_data_without_schema(cls) -> GraphQLObjectData:
return GraphQLObjectData(
fields=cls._build_fields(fields_data=fields_data),
interfaces=[
interface.__name__
NamedTypeNode(name=NameNode(value=interface.__name__))
for interface in inherited_objects
if getattr(interface, "__graphql_type__", None)
== GraphQLClassType.INTERFACE
Expand Down
Loading

0 comments on commit c878df9

Please sign in to comment.