diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..1a93705 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,33 @@ +name: Publish on PyPI + +on: push + +jobs: + build-n-publish: + name: Build and publish to PyPI + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@master + - name: Set up Python 3.9 + uses: actions/setup-python@v1 + with: + python-version: 3.9 + - name: Install pypa/build + run: >- + python -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: >- + python -m + build + --sdist + --wheel + --outdir dist/ + . + - name: Publish distribution to PyPI + if: startsWith(github.ref, 'refs/tags') + uses: pypa/gh-action-pypi-publish@master + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..8143d31 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,33 @@ +name: Tests + +on: [push] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[asgi-file-uploads] + pip install -r requirements-dev.txt + - name: Pytest + run: | + pytest tests ariadne_future --cov=ariadne --cov=tests + - uses: codecov/codecov-action@v1 + - name: Linters + run: | + pylint ariadne ariadne_future tests setup.py + mypy ariadne ariadne_future --ignore-missing-imports + black --check . diff --git a/.gitignore b/.gitignore index b6e4761..d5ca745 100644 --- a/.gitignore +++ b/.gitignore @@ -17,11 +17,10 @@ eggs/ lib/ lib64/ parts/ +pip-wheel-metadata/ sdist/ var/ wheels/ -pip-wheel-metadata/ -share/python-wheels/ *.egg-info/ .installed.cfg *.egg @@ -37,17 +36,19 @@ MANIFEST pip-log.txt pip-delete-this-directory.txt +# OS files +.DS_Store +Thumbs.db + # Unit test / coverage reports htmlcov/ .tox/ -.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover -*.py,cover .hypothesis/ .pytest_cache/ @@ -59,7 +60,6 @@ coverage.xml *.log local_settings.py db.sqlite3 -db.sqlite3-journal # Flask stuff: instance/ @@ -77,26 +77,14 @@ target/ # Jupyter Notebook .ipynb_checkpoints -# IPython -profile_default/ -ipython_config.py +# Vscode +.vscode # pyenv .python-version -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff +# celery beat schedule file celerybeat-schedule -celerybeat.pid # SageMath parsed files *.sage.py @@ -122,8 +110,6 @@ venv.bak/ # mypy .mypy_cache/ -.dmypy.json -dmypy.json -# Pyre type checker -.pyre/ +# PyCharm +.idea diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..054607b --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2018, Mirumee Labs +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..a38f64f --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,8 @@ +include README.md +include LICENSE +graft ariadne_graphql_modules + +global-exclude __pycache__ +global-exclude *.py[co] +global-exclude .DS_Store + diff --git a/README.md b/README.md index 4590709..ac0e132 100644 --- a/README.md +++ b/README.md @@ -1 +1,317 @@ -# ariadne-modules \ No newline at end of file +[![Ariadne](https://ariadnegraphql.org/img/logo-horizontal-sm.png)](https://ariadnegraphql.org) + +- - - - - + +# Ariadne GraphQL Modules + +Ariadne package for implementing Ariadne GraphQL schemas using modular approach. + +For reasoning behind this work, please see [this GitHub discussion](https://github.com/mirumee/ariadne/issues/306). + +See [API reference](./REFERENCE.md) file for documentation. + + +# Examples + +## Basic example + +```python +from datetime import date + +from ariadne.asgi import GraphQL, gql +from ariadne_graphql_modules import ObjectType, make_executable_schema + + +class Query(ObjectType): + __schema__ = gql( + """ + type Query { + message: String! + year: Int! + } + """ + ) + + def resolve_message(*_): + return "Hello world!" + + def resolve_year(*_): + return date.today().year + + +schema = make_executable_schema(Query) +app = GraphQL(schema=schema, debug=True) +``` + + +## Dependency injection + +If `__schema__` string contains other type, its definition should be provided via `__requires__` attribute: + +```python +from typing import List, Optional + +from ariadne.asgi import GraphQL, gql +from ariadne_graphql_modules import ObjectType, make_executable_schema + +from my_app.users import User, get_user, get_last_users + + +class UserType(ObjectType): + __schema__ = gql( + """ + type User { + id: ID! + name: String! + email: String + } + """ + ) + + def resolve_email(user: User, info): + if info.context["is_admin"]: + return user.email + + return None + + +class UsersQueries(ObjectType): + __schema__ = gql( + """ + type Query { + user(id: ID!): User + users: [User!]! + } + """ + ) + __requires__ = [UserType] + + def resolve_user(*_, id: string) -> Optional[User]: + return get_user(id=id) + + def resolve_users(*_, id: string) -> List[User]: + return get_last_users() + + +# UsersQueries already knows about `UserType` so it can be omitted +# in make_executable_schema arguments +schema = make_executable_schema(UsersQueries) +app = GraphQL(schema=schema, debug=True) +``` + + +### Deferred dependencies + +Optionally dependencies can be declared as deferred so they can be provided directly to `make_executable_schema`: + +```python +from typing import List, Optional + +from ariadne.asgi import GraphQL, gql +from ariadne_graphql_modules import DeferredType, ObjectType, make_executable_schema + +from my_app.users import User, get_user, get_last_users + + +class UserType(ObjectType): + __schema__ = gql( + """ + type User { + id: ID! + name: String! + email: String + } + """ + ) + + def resolve_email(user: User, info): + if info.context["is_admin"]: + return user.email + + return None + + +class UsersQueries(ObjectType): + __schema__ = gql( + """ + type Query { + user(id: ID!): User + users: [User!]! + } + """ + ) + __requires__ = [DeferredType("User")] + + def resolve_user(*_, id: string) -> Optional[User]: + return get_user(id=id) + + def resolve_users(*_, id: string) -> List[User]: + return get_last_users() + + +schema = make_executable_schema(UserType, UsersQueries) +app = GraphQL(schema=schema, debug=True) +``` + + +## Automatic case convertion between `python_world` and `clientWorld` + +### Resolving fields values + +Use `__aliases__ = convert_case` to automatically set aliases for fields that convert case + +```python +from ariadne.asgi import gql +from ariadne_graphql_modules import ObjectType, convert_case + + +class UserType(ObjectType): + __schema__ = gql( + """ + type User { + id: ID! + fullName: String! + } + """ + ) + __aliases__ = convert_case +``` + + +### Converting fields arguments + +Use `__fields_args__ = convert_case` on type to automatically convert field arguments to python case in resolver kwargs: + +```python +from ariadne.asgi import gql +from ariadne_graphql_modules import MutationType, convert_case + +from my_app import create_user + + +class UserRegisterMutation(MutationType): + __schema__ = gql( + """ + type Mutation { + registerUser(fullName: String!, email: String!): Boolean! + } + """ + ) + __fields_args__ = convert_case + + async def resolve_mutation(*_, full_name: str, email: str): + user = await create_user( + full_name=full_name, + email=email, + ) + return bool(user) +``` + + +### Converting inputs fields + +Use `__args__ = convert_case` on type to automatically convert input fields to python case in resolver kwargs: + +```python +from ariadne.asgi import gql +from ariadne_graphql_modules import InputType, MutationType, convert_case + +from my_app import create_user + + +class UserRegisterInput(InputType): + __schema__ = gql( + """ + input UserRegisterInput { + fullName: String! + email: String! + } + """ + ) + __args__ = convert_case + + +class UserRegisterMutation(MutationType): + __schema__ = gql( + """ + type Mutation { + registerUser(input: UserRegisterInput!): Boolean! + } + """ + ) + __requires__ = [UserRegisterInput] + + async def resolve_mutation(*_, input: dict): + user = await create_user( + full_name=input["full_name"], + email=input["email"], + ) + return bool(user) +``` + + +## Roots merging + +`Query`, `Mutation` and `Subscription` types are automatically merged into one by `make_executable_schema`: + +```python +from datetime import date + +from ariadne.asgi import GraphQL, gql +from ariadne_graphql_modules import ObjectType, make_executable_schema + + +class YearQuery(ObjectType): + __schema__ = gql( + """ + type Query { + year: Int! + } + """ + ) + + def resolve_year(*_): + return date.today().year + + +class MessageQuery(ObjectType): + __schema__ = gql( + """ + type Query { + message: String! + } + """ + ) + + def resolve_message(*_): + return "Hello world!" + + +schema = make_executable_schema(YearQuery, MessageQuery) +app = GraphQL(schema=schema, debug=True) +``` + +Final schema will contain single `Query` type thats result of merged tupes: + +```graphql +type Query { + message: String! + year: Int! +} +``` + +Fields on final type will be ordered alphabetically. + + +## Contributing + +We are welcoming contributions to Ariadne! If you've found a bug or issue, feel free to use [GitHub issues](https://github.com/mirumee/ariadne/issues). If you have any questions or feedback, don't hesitate to catch us on [GitHub discussions](https://github.com/mirumee/ariadne/discussions/). + +For guidance and instructions, please see [CONTRIBUTING.md](CONTRIBUTING.md). + +Website and the docs have their own GitHub repository: [mirumee/ariadne-website](https://github.com/mirumee/ariadne-website) + +Also make sure you follow [@AriadneGraphQL](https://twitter.com/AriadneGraphQL) on Twitter for latest updates, news and random musings! + +**Crafted with ❤️ by [Mirumee Software](http://mirumee.com)** +hello@mirumee.com diff --git a/REFERENCE.md b/REFERENCE.md new file mode 100644 index 0000000..f6c85bf --- /dev/null +++ b/REFERENCE.md @@ -0,0 +1,414 @@ + + +## `ObjectType` + +New `ObjectType` is base class for Python classes representing GraphQL types (either `type` or `extend type`). + + +### `__schema__` + +`ObjectType` key attribute is `__schema__` string that can define only one GraphQL type: + +```python +class QueryType(ObjectType): + __schema__ = """ + type Query { + year: Int! + } + """ +``` + +`ObjectType` implements validation logic for `__schema__`. It verifies that its valid SDL string defining exactly one GraphQL type. + + +### Resolvers + +Resolvers are class methods or static methods named after schema's fields: + +```python +class QueryType(ObjectType): + __schema__ = """ + type Query { + year: Int! + } + """ + + @staticmethod + def resolve_year(_, info: GraphQLResolveInfo) -> int: + return 2022 +``` + +> `ObjectType` could look up return type of `Int` scalar's `serialize` method and compare it with resolver's return type as extra safety net. + +If resolver function is not present for field, default resolver implemented by `graphql-core` will be used in its place. + +In situations when field's name should be resolved to different value, custom mappings can be defined via `__aliases__` attribute: + +```python +class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + dateJoined: String! + } + """ + __aliases__ = { + "dateJoined": "date_joined" + } +``` + +Above code will result in Ariadne generating resolver resolving `dateJoined` field to `date_joined` attribute on resolved object. + +If `date_joined` exists as `resolve_date_joined` callable on `ObjectType`, it will be used as resolver for `dateJoined`: + +```python +class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + dateJoined: String + } + """ + __aliases__ = { + "dateJoined": "date_joined" + } + + @staticmethod + def resolve_date_joined(user, info) -> Optional[str]: + if can_see_activity(info.context): + return user.date_joined + + return None +``` + +> `ObjectType` could raise error if resolver can't be matched to any field on type. + + +### `__requires__` + +When GraphQL type requires on other GraphQL type (or scalar/directive etc. ect.) `ObjectType` will raise an error about missing dependency. This dependency can be provided through `__requires__` attribute: + +```python +class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + dateJoined: String! + } + """ + + +class UsersGroupType(ObjectType): + __schema__ = """ + type UsersGroup { + id: ID! + users: [User!]! + } + """ + __requires__ = [UserType] +``` + +`ObjectType` verifies that types specified in `__requires__` actually define required types. If `__schema__` in `UserType` is not defining `User`, error will be raised about missing dependency. + +In case of circular dependencies, special `DeferredType` can be used: + +```python +class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + dateJoined: String! + group: UsersGroup + } + """ + __requires__ = [DeferredType("UsersGroup")] + + +class UsersGroupType(ObjectType): + __schema__ = """ + type UsersGroup { + id: ID! + users: [User!]! + } + """ + __requires__ = [UserType] +``` + +`DeferredType` makes `UserType` happy about `UsersGroup` dependency, deferring dependency check to `make_executable_schema`. If "real" `UsersGroup` is not provided at that time, error will be raised about missing types required to create schema. + + +## `SubscriptionType` + +Specialized subclass of `ObjectType` that defines GraphQL subscription: + +```python +class ChatSubscriptions(SubscriptionType): + __schema__ = """ + type Subscription { + chat: Chat + } + """ + __requires__ = [ChatType] + + async def resolve_chat(chat_id, *_): + return await get_chat_from_db(chat_id) + + async def subscribe_chat(*_): + async for event in subscribe("chats"): + yield event["chat_id"] +``` + + +## `InputType` + +Defines GraphQL input: + +```python +class UserCreateInput(InputType): + __schema__ = """ + input UserInput { + name: String! + email: String! + fullName: String! + } + """ + __args__ = { + "fullName": "full_name", + } +``` + +### `__args__` + +Optional attribue `__args__` is a `Dict[str, str]` used to override key names for `dict` representing input's data. + +Following JSON: + +```json +{ + "name": "Alice", + "email:" "alice@example.com", + "fullName": "Alice Chains" +} +``` + +Will be represented as following dict: + +```python +{ + "name": "Alice", + "email": "alice@example.com", + "full_name": "Alice Chains", +} +``` + + +## `ScalarType` + +Allows you to define custom scalar in your GraphQL schema. + +```python +class DateScalar(ScalarType): + __schema__ = "scalar Datetime" + + @staticmethod + def serialize(value) -> str: + # Called by GraphQL to serialize Python value to + # JSON-serializable format + return value.strftime("%Y-%m-%d") + + @staticmethod + def parse_value(value) -> str: + # Called by GraphQL to parse JSON-serialized value to + # Python type + parsed_datetime = datetime.strptime(formatted_date, "%Y-%m-%d") + return parsed_datetime.date() +``` + +Note that those methods are only required if Python type is not JSON serializable, or you want to customize its serialization process. + +Additionally you may define third method called `parse_literal` that customizes value's deserialization from GraphQL query's AST, but this is only useful for complex types that represent objects: + +```python +from graphql import StringValueNode + + +class DateScalar(Scalar): + __schema__ = "scalar Datetime" + + @staticmethod + def def parse_literal(ast, variable_values: Optional[Dict[str, Any]] = None): + if not isinstance(ast, StringValueNode): + raise ValueError() + + parsed_datetime = datetime.strptime(ast.value, "%Y-%m-%d") + return parsed_datetime.date() +``` + +If you won't define `parse_literal`, GraphQL will use custom logic that will unpack value from AST and then call `parse_value` on it. + + +## `InterfaceType` + +Defines intefrace in GraphQL schema: + +```python +class SearchResultInterface(InterfaceType): + __schema__ = """ + interface SearchResult { + summary: String! + score: Int! + } + """ + + @staticmethod + def resolve_type(obj, info): + # Returns string with name of GraphQL type representing Python type + # from your business logic + if isinstance(obj, UserModel): + return UserType.graphql_name + + if isinstance(obj, CommentModel): + return CommentType.graphql_name + + return None + + @staticmethod + def resolve_summary(obj, info): + # Optional default resolver for summary field, used by types implementing + # this interface when they don't implement their own +``` + + +## `UnionType` + +Defines GraphQL union: + +```python +class SearchResultUnion(UnionType): + __schema__ = "union SearchResult = User | Post | Thread" + __requires__ = [UserType, PostType, ThreadType] + + @staticmethod + def resolve_type(obj, info): + # Returns string with name of GraphQL type representing Python type + # from your business logic + if isinstance(obj, UserModel): + return UserType.graphql_name + + if isinstance(obj, PostModel): + return PostType.graphql_name + + if isinstance(obj, ThreadModel): + return ThreadType.graphql_name + + return None +``` + + +## `DirectiveType` + +Defines new GraphQL directive in your schema and specifies `SchemaDirectiveVisitor` for it: + + +```python +from ariadne import SchemaDirectiveVisitor +from graphql import default_field_resolver + + +class PrefixStringSchemaVisitor(SchemaDirectiveVisitor): + def visit_field_definition(self, field, object_type): + original_resolver = field.resolve or default_field_resolver + + def resolve_prefixed_value(obj, info, **kwargs): + result = original_resolver(obj, info, **kwargs) + if result: + return f"PREFIX: {result}" + return result + + field.resolve = resolve_prefixed_value + return field + + +class PrefixStringDirective(DirectiveType): + __schema__ = "directive @example on FIELD_DEFINITION" + __visitor__ = PrefixStringSchemaVisitor +``` + + +## `make_executable_schema` + +New `make_executable_schema` takes list of Ariadne's types and constructs executable schema from them, performing last-stage validation for types consistency: + +```python +class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + username: String! + } + """ + + +class QueryType(ObjectType): + __schema__ = """ + type Query { + user: User + } + """ + __requires__ = [UserType] + + @staticmethod + def user(*_): + return { + "id": 1, + "username": "Alice", + } + + +schema = make_executable_schema(QueryType) +``` + + +### Automatic merging of roots + +Passing multiple `Query`, `Mutation` or `Subscription` definitions to `make_executable_schema` by default results in schema defining single types containing sum of all fields defined on those types, ordered alphabetically by field name. + +```python +class UserQueriesType(ObjectType): + __schema__ = """ + type Query { + user(id: ID!): User + } + """ + ... + + +class ProductsQueriesType(ObjectType): + __schema__ = """ + type Query { + product(id: ID!): Product + } + """ + ... + +schema = make_executable_schema(UserQueriesType, ProductsQueriesType) +``` + +Above schema will have single `Query` type looking like this: + +```graphql +type Query { + product(id: ID!): Product + user(id: ID!): User +} +``` + +To opt out of this behavior use `merge_roots=False` option: + +```python +schema = make_executable_schema( + UserQueriesType, + ProductsQueriesType, + merge_roots=False, +) +``` diff --git a/ariadne_graphql_modules/__init__.py b/ariadne_graphql_modules/__init__.py new file mode 100644 index 0000000..a25b060 --- /dev/null +++ b/ariadne_graphql_modules/__init__.py @@ -0,0 +1,30 @@ +from .convert_case import convert_case +from .deferred_type import DeferredType +from .directive_type import DirectiveType +from .enum_type import EnumType +from .executable_schema import make_executable_schema +from .input_type import InputType +from .interface_type import InterfaceType +from .mutation_type import MutationType +from .object_type import ObjectType +from .scalar_type import ScalarType +from .subscription_type import SubscriptionType +from .union_type import UnionType +from .utils import create_alias_resolver, parse_definition + +__all__ = [ + "DeferredType", + "DirectiveType", + "EnumType", + "InputType", + "InterfaceType", + "MutationType", + "ObjectType", + "ScalarType", + "SubscriptionType", + "UnionType", + "convert_case", + "create_alias_resolver", + "make_executable_schema", + "parse_definition", +] diff --git a/ariadne_graphql_modules/base_type.py b/ariadne_graphql_modules/base_type.py new file mode 100644 index 0000000..8f1acc5 --- /dev/null +++ b/ariadne_graphql_modules/base_type.py @@ -0,0 +1,34 @@ +from typing import List, Type + +from graphql import DefinitionNode, GraphQLSchema + +from .dependencies import Dependencies +from .types import RequirementsDict + + +class BaseType: + __abstract__ = True + __schema__: str + __requires__: List[Type["BaseType"]] = [] + + graphql_name: str + graphql_type: Type[DefinitionNode] + + @classmethod + def __get_requirements__(cls) -> RequirementsDict: + return {req.graphql_name: req.graphql_type for req in cls.__requires__} + + @classmethod + def __validate_requirements__( + cls, requirements: RequirementsDict, dependencies: Dependencies + ): + for graphql_name in dependencies: + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} class was defined without required GraphQL " + f"definition for '{graphql_name}' in __requires__" + ) + + @classmethod + def __bind_to_schema__(cls, schema: GraphQLSchema): + raise NotImplementedError() diff --git a/ariadne_graphql_modules/convert_case.py b/ariadne_graphql_modules/convert_case.py new file mode 100644 index 0000000..da464e8 --- /dev/null +++ b/ariadne_graphql_modules/convert_case.py @@ -0,0 +1,74 @@ +from typing import Dict, Optional, Union, cast + +from ariadne import convert_camel_case_to_snake +from graphql import DefinitionNode + +from .types import FieldsDict + +Overrides = Dict[str, str] +ArgsOverrides = Dict[str, Overrides] + + +def convert_case( + overrides_or_fields: Optional[Union[FieldsDict, dict]] = None, + map_fields_args=False, +): + no_args_call = convert_case_call_without_args(overrides_or_fields) + + overrides = {} + if not no_args_call: + overrides = cast(dict, overrides_or_fields) + + def create_case_mappings(fields: FieldsDict, map_fields_args=False): + if map_fields_args: + return convert_args_cas(fields, overrides) + + return convert_aliases_case(fields, overrides) + + if no_args_call: + fields = cast(FieldsDict, overrides_or_fields) + return create_case_mappings(fields, map_fields_args) + + return create_case_mappings + + +def convert_case_call_without_args( + overrides_or_fields: Optional[Union[FieldsDict, dict]] = None +) -> bool: + if overrides_or_fields is None: + return True + + if isinstance(list(overrides_or_fields.values())[0], DefinitionNode): + return True + + return False + + +def convert_aliases_case(fields: FieldsDict, overrides: Overrides) -> Overrides: + final_mappings = {} + for field_name in fields: + if field_name in overrides: + field_name_final = overrides[field_name] + else: + field_name_final = convert_camel_case_to_snake(field_name) + if field_name != field_name_final: + final_mappings[field_name] = field_name_final + return final_mappings + + +def convert_args_cas(fields: FieldsDict, overrides: ArgsOverrides) -> ArgsOverrides: + final_mappings = {} + for field_name, field_def in fields.items(): + arg_overrides: Overrides = overrides.get(field_name, {}) + arg_mappings = {} + for arg in field_def.arguments: + arg_name = arg.name.value + if arg_name in arg_overrides: + arg_name_final = arg_overrides[arg_name] + else: + arg_name_final = convert_camel_case_to_snake(arg_name) + if arg_name != arg_name_final: + arg_mappings[arg_name] = arg_name_final + if arg_mappings: + final_mappings[field_name] = arg_mappings + return final_mappings diff --git a/ariadne_graphql_modules/deferred_type.py b/ariadne_graphql_modules/deferred_type.py new file mode 100644 index 0000000..6c8e73a --- /dev/null +++ b/ariadne_graphql_modules/deferred_type.py @@ -0,0 +1,14 @@ +from graphql import ObjectTypeDefinitionNode + +from .base_type import BaseType + + +class DeferredType(BaseType): + graphql_type = ObjectTypeDefinitionNode + + def __init__(self, name: str): + self.graphql_name = name + + @classmethod + def __bind_to_schema__(cls, *_): + raise NotImplementedError("DeferredType cannot be bound to schema") diff --git a/ariadne_graphql_modules/dependencies.py b/ariadne_graphql_modules/dependencies.py new file mode 100644 index 0000000..a099055 --- /dev/null +++ b/ariadne_graphql_modules/dependencies.py @@ -0,0 +1,140 @@ +from typing import Tuple, Union, Set + +from graphql import ( + ConstDirectiveNode, + FieldDefinitionNode, + InputObjectTypeDefinitionNode, + InputObjectTypeExtensionNode, + InputValueDefinitionNode, + InterfaceTypeDefinitionNode, + InterfaceTypeExtensionNode, + NamedTypeNode, + ObjectTypeDefinitionNode, + ObjectTypeExtensionNode, + UnionTypeDefinitionNode, + UnionTypeExtensionNode, +) + +from .utils import unwrap_type_node + +GRAPHQL_TYPES = ("ID", "Int", "String", "Boolean") + +Dependencies = Tuple[str, ...] + + +def get_dependencies_from_object_type( + graphql_type: Union[ + InterfaceTypeDefinitionNode, + InterfaceTypeExtensionNode, + ObjectTypeDefinitionNode, + ObjectTypeExtensionNode, + ] +) -> Dependencies: + dependencies: Set[str] = set() + dependencies.update( + get_dependencies_from_directives(graphql_type.directives), + get_dependencies_from_fields(graphql_type.fields), + get_dependencies_from_interfaces(graphql_type.interfaces), + ) + + if graphql_type.name.value in dependencies: + # Remove self-dependency + dependencies.remove(graphql_type.name.value) + + return tuple(dependencies) + + +def get_dependencies_from_input_type( + graphql_type: Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode] +) -> Dependencies: + dependencies: Set[str] = set() + dependencies.update( + get_dependencies_from_directives(graphql_type.directives), + get_dependencies_from_input_fields(graphql_type.fields), + ) + + if graphql_type.name.value in dependencies: + # Remove self-dependency + dependencies.remove(graphql_type.name.value) + + return tuple(dependencies) + + +def get_dependencies_from_union_type( + graphql_type: Union[UnionTypeDefinitionNode, UnionTypeExtensionNode] +) -> Dependencies: + dependencies: Set[str] = set() + dependencies.update( + get_dependencies_from_directives(graphql_type.directives), + get_dependencies_from_interfaces(graphql_type.types), + ) + + if graphql_type.name.value in dependencies: + # Remove self-dependency + dependencies.remove(graphql_type.name.value) + + return tuple(dependencies) + + +def get_dependencies_from_directives( + directives: Tuple[ConstDirectiveNode, ...] +) -> Dependencies: + dependencies: Set[str] = set() + for directive in directives: + dependencies.add(directive.name.value) + return tuple(dependencies) + + +def get_dependencies_from_fields( + fields: Tuple[FieldDefinitionNode, ...] +) -> Dependencies: + dependencies: Set[str] = set() + + for field_def in fields: + dependencies.update(get_dependencies_from_directives(field_def.directives)) + + # Get dependency from return type + field_type = unwrap_type_node(field_def.type) + if isinstance(field_type, NamedTypeNode): + field_type_name = field_type.name.value + if field_type_name not in GRAPHQL_TYPES: + dependencies.add(field_type_name) + + # Get dependency from arguments + for arg_def in field_def.arguments: + dependencies.update(get_dependencies_from_directives(arg_def.directives)) + + arg_type = unwrap_type_node(arg_def.type) + if isinstance(arg_type, NamedTypeNode): + arg_type_name = arg_type.name.value + if arg_type_name not in GRAPHQL_TYPES: + dependencies.add(arg_type_name) + + return tuple(dependencies) + + +def get_dependencies_from_input_fields( + fields: Tuple[InputValueDefinitionNode, ...] +) -> Dependencies: + dependencies: Set[str] = set() + + for field_def in fields: + dependencies.update(get_dependencies_from_directives(field_def.directives)) + + # Get dependency from return type + field_type = unwrap_type_node(field_def.type) + if isinstance(field_type, NamedTypeNode): + field_type_name = field_type.name.value + if field_type_name not in GRAPHQL_TYPES: + dependencies.add(field_type_name) + + return tuple(dependencies) + + +def get_dependencies_from_interfaces( + interfaces: Tuple[NamedTypeNode, ...] +) -> Dependencies: + dependencies: Set[str] = set() + for interface in interfaces: + dependencies.add(interface.name.value) + return tuple(dependencies) diff --git a/ariadne_graphql_modules/directive_type.py b/ariadne_graphql_modules/directive_type.py new file mode 100644 index 0000000..c857fb0 --- /dev/null +++ b/ariadne_graphql_modules/directive_type.py @@ -0,0 +1,53 @@ +from typing import Type, cast + +from ariadne import SchemaDirectiveVisitor +from graphql import ( + DefinitionNode, + DirectiveDefinitionNode, +) + +from .base_type import BaseType +from .utils import parse_definition + + +class DirectiveType(BaseType): + __abstract__ = True + __visitor__: Type[SchemaDirectiveVisitor] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + + cls.__validate_visitor__() + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> DirectiveDefinitionNode: + if not isinstance(type_def, DirectiveDefinitionNode): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ " + "without GraphQL directive" + ) + + return cast(DirectiveDefinitionNode, type_def) + + @classmethod + def __validate_visitor__(cls): + if not getattr(cls, "__visitor__", None): + raise AttributeError( + f"{cls.__name__} class was defined without __visitor__ attribute" + ) + + @staticmethod + def __bind_to_schema__(*_): + pass # Binding directive to schema is noop diff --git a/ariadne_graphql_modules/enum_type.py b/ariadne_graphql_modules/enum_type.py new file mode 100644 index 0000000..9feb66e --- /dev/null +++ b/ariadne_graphql_modules/enum_type.py @@ -0,0 +1,121 @@ +from enum import Enum +from typing import List, Optional, Type, Union, cast + +import ariadne +from graphql import ( + DefinitionNode, + GraphQLSchema, + EnumTypeDefinitionNode, + EnumTypeExtensionNode, +) + +from .base_type import BaseType +from .types import RequirementsDict +from .utils import parse_definition + +EnumNodeType = Union[EnumTypeDefinitionNode, EnumTypeExtensionNode] + + +class EnumType(BaseType): + __abstract__ = True + __enum__: Optional[Union[Type[Enum], dict]] = None + + graphql_type: Union[Type[EnumTypeDefinitionNode], Type[EnumTypeExtensionNode]] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + + requirements = cls.__get_requirements__() + cls.__validate_requirements_contain_extended_type__(graphql_def, requirements) + + values = cls.__get_values__(graphql_def) + cls.__validate_values__(values) + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> EnumNodeType: + if not isinstance(type_def, (EnumTypeDefinitionNode, EnumTypeExtensionNode)): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ without GraphQL enum" + ) + + return cast(EnumNodeType, type_def) + + @classmethod + def __validate_requirements_contain_extended_type__( + cls, type_def: EnumNodeType, requirements: RequirementsDict + ): + if not isinstance(type_def, EnumTypeExtensionNode): + return + + graphql_name = type_def.name.value + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} graphql type was defined without required GraphQL " + f"type definition for '{graphql_name}' in __requires__" + ) + + if requirements[graphql_name] != EnumTypeDefinitionNode: + raise ValueError( + f"{cls.__name__} requires '{graphql_name}' to be GraphQL enum " + f"but other type was provided in '__requires__'" + ) + + @classmethod + def __get_values__(cls, type_def: EnumNodeType) -> List[str]: + if not type_def.values and not ( + isinstance(type_def, EnumTypeExtensionNode) and type_def.directives + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing " + f"empty, GraphQL enum definition" + ) + + return [value.name.value for value in type_def.values] + + @classmethod + def __validate_values__(cls, values: List[str]): + if not cls.__enum__: + return + + if isinstance(cls.__enum__, dict): + enum_keys = set(cls.__enum__.keys()) + else: + enum_keys = set(cls.__enum__.__members__.keys()) + + missing_keys = set(values) - enum_keys + if missing_keys: + raise ValueError( + f"{cls.__name__} class was defined with __enum__ missing following " + f"items required by GraphQL definition: {', '.join(missing_keys)}" + ) + + extra_keys = enum_keys - set(values) + if extra_keys: + raise ValueError( + f"{cls.__name__} class was defined with __enum__ containing extra " + f"items missing in GraphQL definition: {', '.join(extra_keys)}" + ) + + @classmethod + def __bind_to_schema__(cls, schema: GraphQLSchema): + if cls.__enum__: + bindable = ariadne.EnumType(cls.graphql_name, cls.__enum__) + bindable.bind_to_schema(schema) + + @classmethod + def __bind_to_default_values__(cls, schema: GraphQLSchema): + if cls.__enum__: + bindable = ariadne.EnumType(cls.graphql_name, cls.__enum__) + bindable.bind_to_default_values(schema) diff --git a/ariadne_graphql_modules/executable_schema.py b/ariadne_graphql_modules/executable_schema.py new file mode 100644 index 0000000..90adab9 --- /dev/null +++ b/ariadne_graphql_modules/executable_schema.py @@ -0,0 +1,171 @@ +from typing import Dict, Iterable, List, Tuple, Type, cast + +from ariadne import ( + SchemaDirectiveVisitor, + set_default_enum_values_on_schema, + validate_schema_enum_values, +) +from graphql import ( + ConstDirectiveNode, + DocumentNode, + FieldDefinitionNode, + GraphQLSchema, + NamedTypeNode, + ObjectTypeDefinitionNode, + assert_valid_schema, + build_ast_schema, + concat_ast, + parse, +) +from graphql.language import ast + +from .base_type import BaseType +from .deferred_type import DeferredType +from .enum_type import EnumType + +ROOT_TYPES = ["Query", "Mutation", "Subscription"] + + +def make_executable_schema( + *types, + merge_roots: bool = True, +): + all_types: List[Type[BaseType]] = [] + find_requirements(all_types, types) + + real_types = [type_ for type_ in all_types if not isinstance(type_, DeferredType)] + validate_no_missing_types(real_types, all_types) + + schema = build_schema(real_types, merge_roots) + set_default_enum_values_on_schema(schema) + assert_valid_schema(schema) + validate_schema_enum_values(schema) + repair_default_enum_values(schema, real_types) + + add_directives_to_schema(schema, real_types) + + return schema + + +def find_requirements( + types_list: List[Type[BaseType]], types: Iterable[Type[BaseType]] +): + for type_ in types: + if type_ not in types_list: + types_list.append(type_) + + find_requirements(types_list, type_.__requires__) + + +def validate_no_missing_types( + real_types: List[Type[BaseType]], all_types: List[Type[BaseType]] +): + deferred_names = [ + deferred.graphql_name + for deferred in all_types + if isinstance(deferred, DeferredType) + ] + + real_names = [type_.graphql_name for type_ in real_types] + missing_names = set(deferred_names) - set(real_names) + if missing_names: + raise ValueError( + "Following types are defined as deferred and are missing " + f"from schema: {', '.join(missing_names)}" + ) + + +def build_schema( + types_list: List[Type[BaseType]], merge_roots: bool = True +) -> GraphQLSchema: + schema_definitions: List[ast.DocumentNode] = [] + if merge_roots: + schema_definitions.append(build_root_schema(types_list)) + for type_ in types_list: + if type_.graphql_name not in ROOT_TYPES or not merge_roots: + schema_definitions.append(parse(type_.__schema__)) + + ast_document = concat_ast(schema_definitions) + schema = build_ast_schema(ast_document) + + for type_ in types_list: + type_.__bind_to_schema__(schema) + + return schema + + +def build_root_schema(types_list: List[Type[BaseType]]) -> DocumentNode: + root_types: Dict[str, List[Type[BaseType]]] = { + "Query": [], + "Mutation": [], + "Subscription": [], + } + + for type_ in types_list: + if type_.graphql_name in root_types: + root_types[type_.graphql_name].append(type_) + + schema: List[DocumentNode] = [] + for types_defs in root_types.values(): + if len(types_defs) == 1: + schema.append(parse(types_defs[0].__schema__)) + elif types_defs: + schema.append(merge_root_types(types_defs)) + + return concat_ast(schema) + + +def merge_root_types(types_list: List[Type[BaseType]]) -> DocumentNode: + interfaces: List[NamedTypeNode] = [] + directives: List[ConstDirectiveNode] = [] + fields: Dict[str, Tuple[FieldDefinitionNode, Type[BaseType]]] = {} + + for type_ in types_list: + type_definition = cast( + ObjectTypeDefinitionNode, + parse(type_.__schema__).definitions[0], + ) + interfaces.extend(type_definition.interfaces) + directives.extend(type_definition.directives) + + for field_def in type_definition.fields: + field_name = field_def.name.value + if field_name in fields: + other_type_name = fields[field_name][1].__name__ + raise ValueError( + f"Multiple {type_.graphql_name} types are defining same field " + f"'{field_name}': {other_type_name}, {type_.__name__}" + ) + + fields[field_name] = (field_def, type_) + + merged_definition = ast.ObjectTypeDefinitionNode() + merged_definition.name = ast.NameNode() + merged_definition.name.value = types_list[0].graphql_name + merged_definition.interfaces = tuple(interfaces) + merged_definition.directives = tuple(directives) + merged_definition.fields = tuple( + fields[field_name][0] for field_name in sorted(fields) + ) + + merged_document = DocumentNode() + merged_document.definitions = (merged_definition,) + + return merged_document + + +def add_directives_to_schema(schema: GraphQLSchema, types_list: List[Type[BaseType]]): + directives: Dict[str, Type[SchemaDirectiveVisitor]] = {} + for type_ in types_list: + visitor = getattr(type_, "__visitor__", None) + if visitor and issubclass(visitor, SchemaDirectiveVisitor): + directives[type_.graphql_name] = visitor + + if directives: + SchemaDirectiveVisitor.visit_schema_directives(schema, directives) + + +def repair_default_enum_values(schema, types_list: List[Type[BaseType]]) -> None: + for type_ in types_list: + if issubclass(type_, EnumType): + type_.__bind_to_default_values__(schema) diff --git a/ariadne_graphql_modules/input_type.py b/ariadne_graphql_modules/input_type.py new file mode 100644 index 0000000..8487d2b --- /dev/null +++ b/ariadne_graphql_modules/input_type.py @@ -0,0 +1,118 @@ +from typing import Callable, Dict, Optional, Union, cast + +from graphql import ( + DefinitionNode, + InputObjectTypeDefinitionNode, + InputObjectTypeExtensionNode, +) + +from .base_type import BaseType +from .dependencies import Dependencies, get_dependencies_from_input_type +from .types import InputFieldsDict, RequirementsDict +from .utils import parse_definition + +Args = Dict[str, str] +InputNodeType = Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode] + + +class InputType(BaseType): + __abstract__ = True + __args__: Optional[Union[Args, Callable[..., Args]]] = None + + graphql_fields: InputFieldsDict + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + cls.graphql_fields = cls.__get_fields__(graphql_def) + + if callable(cls.__args__): + # pylint: disable=not-callable + cls.__args__ = cls.__args__(cls.graphql_fields) + + cls.__validate_args__() + + requirements = cls.__get_requirements__() + cls.__validate_requirements_contain_extended_type__(graphql_def, requirements) + + dependencies = cls.__get_dependencies__(graphql_def) + cls.__validate_requirements__(requirements, dependencies) + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> InputNodeType: + if not isinstance( + type_def, (InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ without GraphQL input" + ) + + return cast(InputNodeType, type_def) + + @classmethod + def __validate_requirements_contain_extended_type__( + cls, type_def: InputNodeType, requirements: RequirementsDict + ): + if not isinstance(type_def, InputObjectTypeExtensionNode): + return + + graphql_name = type_def.name.value + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} graphql type was defined without required GraphQL " + f"type definition for '{graphql_name}' in __requires__" + ) + + if requirements[graphql_name] != InputObjectTypeDefinitionNode: + raise ValueError( + f"{cls.__name__} requires '{graphql_name}' to be GraphQL input " + f"but other type was provided in '__requires__'" + ) + + @classmethod + def __get_fields__(cls, type_def: InputNodeType) -> InputFieldsDict: + if not type_def.fields and not ( + isinstance(type_def, InputObjectTypeExtensionNode) and type_def.directives + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing empty " + f"GraphQL input definition" + ) + + return {field.name.value: field for field in type_def.fields} + + @classmethod + def __validate_args__(cls): + if not cls.__args__: + return + + invalid_args = set(cls.__args__) - set(cls.graphql_fields) + if invalid_args: + raise ValueError( + f"{cls.__name__} class was defined with args for fields not in " + f"GraphQL input: {', '.join(invalid_args)}" + ) + + @classmethod + def __get_dependencies__(cls, type_def: InputNodeType) -> Dependencies: + return get_dependencies_from_input_type(type_def) + + @classmethod + def __bind_to_schema__(cls, schema): + if not cls.__args__: + return + + graphql_type = schema.type_map.get(cls.graphql_name) + for field_name, field_target in cls.__args__.items(): + graphql_type.fields[field_name].out_name = field_target diff --git a/ariadne_graphql_modules/interface_type.py b/ariadne_graphql_modules/interface_type.py new file mode 100644 index 0000000..9fcf328 --- /dev/null +++ b/ariadne_graphql_modules/interface_type.py @@ -0,0 +1,136 @@ +from typing import Dict, Callable, Type, Union, cast + +from graphql import ( + DefinitionNode, + GraphQLFieldResolver, + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLSchema, + GraphQLTypeResolver, + InterfaceTypeDefinitionNode, + InterfaceTypeExtensionNode, +) + +from ariadne import type_implements_interface + +from .base_type import BaseType +from .dependencies import Dependencies, get_dependencies_from_object_type +from .resolvers_mixin import ResolversMixin +from .types import FieldsDict, RequirementsDict +from .utils import parse_definition + +InterfaceNodeType = Union[InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode] + + +class InterfaceType(BaseType, ResolversMixin): + __abstract__ = True + + graphql_name: str + graphql_type: Union[ + Type[InterfaceTypeDefinitionNode], Type[InterfaceTypeExtensionNode] + ] + + resolve_type: GraphQLTypeResolver + resolvers: Dict[str, GraphQLFieldResolver] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + cls.graphql_fields = cls.__get_fields__(graphql_def) + + requirements = cls.__get_requirements__() + cls.__validate_requirements_contain_extended_type__(graphql_def, requirements) + + dependencies = cls.__get_dependencies__(graphql_def) + cls.__validate_requirements__(requirements, dependencies) + + if callable(cls.__fields_args__): + cls.__fields_args__ = cls.__fields_args__(cls.graphql_fields, True) + + cls.__validate_fields_args__() + + if callable(cls.__aliases__): + cls.__aliases__ = cls.__aliases__(cls.graphql_fields) + + cls.__validate_aliases__() + cls.resolvers = cls.__get_resolvers__() + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> InterfaceNodeType: + if not isinstance( + type_def, (InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ without " + "GraphQL interface" + ) + + return cast(InterfaceNodeType, type_def) + + @classmethod + def __validate_requirements_contain_extended_type__( + cls, type_def: InterfaceNodeType, requirements: RequirementsDict + ): + if not isinstance(type_def, InterfaceTypeExtensionNode): + return + + graphql_name = type_def.name.value + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} graphql type was defined without required GraphQL " + f"type definition for '{graphql_name}' in __requires__" + ) + + if requirements[graphql_name] != InterfaceTypeDefinitionNode: + raise ValueError( + f"{cls.__name__} requires '{graphql_name}' to be GraphQL interface " + f"but other type was provided in '__requires__'" + ) + + @classmethod + def __get_fields__(cls, type_def: InterfaceNodeType) -> FieldsDict: + if not type_def.fields and not ( + isinstance(type_def, InterfaceTypeExtensionNode) + and (type_def.directives or type_def.interfaces) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing empty " + f"GraphQL interface definition" + ) + + return {field.name.value: field for field in type_def.fields} + + @classmethod + def __get_defined_resolvers__(cls) -> Dict[str, Callable]: + resolvers = super().__get_defined_resolvers__() + resolvers.pop("type", None) + return resolvers + + @classmethod + def __get_dependencies__(cls, type_def: InterfaceNodeType) -> Dependencies: + return get_dependencies_from_object_type(type_def) + + @classmethod + def __bind_to_schema__(cls, schema: GraphQLSchema): + graphql_type = cast(GraphQLInterfaceType, schema.type_map.get(cls.graphql_name)) + graphql_type.resolve_type = cls.resolve_type + + for type_ in schema.type_map.values(): + if not type_implements_interface(cls.graphql_name, type_): + continue + + type_ = cast(GraphQLObjectType, type_) + for field_name, field_resolver in cls.resolvers.items(): + if not type_.fields[field_name].resolve: + type_.fields[field_name].resolve = field_resolver diff --git a/ariadne_graphql_modules/mutation_type.py b/ariadne_graphql_modules/mutation_type.py new file mode 100644 index 0000000..111acec --- /dev/null +++ b/ariadne_graphql_modules/mutation_type.py @@ -0,0 +1,150 @@ +from typing import Dict, Optional, Type, Union, cast + +from graphql import ( + DefinitionNode, + FieldDefinitionNode, + GraphQLFieldResolver, + ObjectTypeDefinitionNode, + ObjectTypeExtensionNode, +) + +from .base_type import BaseType +from .dependencies import Dependencies, get_dependencies_from_object_type +from .types import RequirementsDict +from .utils import parse_definition + +ObjectNodeType = Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode] + + +class MutationType(BaseType): + __abstract__ = True + __args__: Optional[Dict[str, str]] = None + + graphql_name = "Mutation" + graphql_type: Union[Type[ObjectTypeDefinitionNode], Type[ObjectTypeExtensionNode]] + + mutation_name: str + resolve_mutation: GraphQLFieldResolver + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + + field = cls.__get_field__(graphql_def) + cls.mutation_name = field.name.value + + requirements = cls.__get_requirements__() + cls.__validate_requirements_contain_extended_type__(graphql_def, requirements) + + dependencies = cls.__get_dependencies__(graphql_def) + cls.__validate_requirements__(requirements, dependencies) + + cls.__validate_args__(field) + cls.__validate_resolve_mutation__() + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> ObjectNodeType: + if not isinstance( + type_def, (ObjectTypeDefinitionNode, ObjectTypeExtensionNode) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ without GraphQL type" + ) + + if type_def.name.value != "Mutation": + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing " + f"GraphQL definition for 'type {type_def.name.value}' while " + "'type Mutation' was expected" + ) + + return cast(ObjectNodeType, type_def) + + @classmethod + def __validate_requirements_contain_extended_type__( + cls, type_def: ObjectNodeType, requirements: RequirementsDict + ): + if not isinstance(type_def, ObjectTypeExtensionNode): + return + + graphql_name = type_def.name.value + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} graphql type was defined without required GraphQL " + f"type definition for '{graphql_name}' in __requires__" + ) + + if requirements[graphql_name] != ObjectTypeDefinitionNode: + raise ValueError( + f"{cls.__name__} requires '{graphql_name}' to be GraphQL type " + f"but other type was provided in '__requires__'" + ) + + @classmethod + def __get_field__(cls, type_def: ObjectNodeType) -> FieldDefinitionNode: + if not type_def.fields: + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing " + f"empty GraphQL type definition" + ) + + if len(type_def.fields) != 1: + raise ValueError( + f"{cls.__name__} class subclasses 'MutationType' class which " + "requires __schema__ to define exactly one field" + ) + + return type_def.fields[0] + + @classmethod + def __get_dependencies__(cls, type_def: ObjectNodeType) -> Dependencies: + return get_dependencies_from_object_type(type_def) + + @classmethod + def __validate_args__(cls, field: FieldDefinitionNode): + if not cls.__args__: + return + + field_args = [arg.name.value for arg in field.arguments] + invalid_args = set(cls.__args__) - set(field_args) + if invalid_args: + raise ValueError( + f"{cls.__name__} class was defined with args not on " + f"'{field.name.value}' GraphQL field: {', '.join(invalid_args)}" + ) + + @classmethod + def __validate_resolve_mutation__(cls): + resolver = getattr(cls, "resolve_mutation", None) + if not resolver: + raise AttributeError( + f"{cls.__name__} class was defined without required " + "'resolve_mutation' attribute" + ) + + if not callable(resolver): + raise TypeError( + f"{cls.__name__} class was defined with attribute " + "'resolve_mutation' but it's not callable" + ) + + @classmethod + def __bind_to_schema__(cls, schema): + graphql_type = schema.type_map.get(cls.graphql_name) + graphql_type.fields[cls.mutation_name].resolve = cls.resolve_mutation + + if cls.__args__: + field_args = graphql_type.fields[cls.mutation_name].args + for arg_name, out_name in cls.__args__.items(): + field_args[arg_name].out_name = out_name diff --git a/ariadne_graphql_modules/object_type.py b/ariadne_graphql_modules/object_type.py new file mode 100644 index 0000000..ed1d81f --- /dev/null +++ b/ariadne_graphql_modules/object_type.py @@ -0,0 +1,120 @@ +from typing import Dict, Type, Union, cast + +from graphql import ( + DefinitionNode, + GraphQLFieldResolver, + ObjectTypeDefinitionNode, + ObjectTypeExtensionNode, +) + +from .base_type import BaseType +from .dependencies import Dependencies, get_dependencies_from_object_type +from .resolvers_mixin import ResolversMixin +from .types import FieldsDict, RequirementsDict +from .utils import parse_definition + +ObjectNodeType = Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode] + + +class ObjectType(BaseType, ResolversMixin): + __abstract__ = True + + graphql_name: str + graphql_type: Union[Type[ObjectTypeDefinitionNode], Type[ObjectTypeExtensionNode]] + + resolvers: Dict[str, GraphQLFieldResolver] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + cls.graphql_fields = cls.__get_fields__(graphql_def) + + requirements = cls.__get_requirements__() + cls.__validate_requirements_contain_extended_type__(graphql_def, requirements) + + dependencies = cls.__get_dependencies__(graphql_def) + cls.__validate_requirements__(requirements, dependencies) + + if callable(cls.__fields_args__): + cls.__fields_args__ = cls.__fields_args__(cls.graphql_fields, True) + + cls.__validate_fields_args__() + + if callable(cls.__aliases__): + cls.__aliases__ = cls.__aliases__(cls.graphql_fields) + + cls.__validate_aliases__() + cls.resolvers = cls.__get_resolvers__() + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> ObjectNodeType: + if not isinstance( + type_def, (ObjectTypeDefinitionNode, ObjectTypeExtensionNode) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ without GraphQL type" + ) + + if type_def.name.value == "Subscription": + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing " + f"GraphQL definition for 'type Subscription' which is only supported " + "by subsclassess of 'SubscriptionType'" + ) + + return cast(ObjectNodeType, type_def) + + @classmethod + def __validate_requirements_contain_extended_type__( + cls, type_def: ObjectNodeType, requirements: RequirementsDict + ): + if not isinstance(type_def, ObjectTypeExtensionNode): + return + + graphql_name = type_def.name.value + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} graphql type was defined without required GraphQL " + f"type definition for '{graphql_name}' in __requires__" + ) + + if requirements[graphql_name] != ObjectTypeDefinitionNode: + raise ValueError( + f"{cls.__name__} requires '{graphql_name}' to be GraphQL type " + f"but other type was provided in '__requires__'" + ) + + @classmethod + def __get_fields__(cls, type_def: ObjectNodeType) -> FieldsDict: + if not type_def.fields and not ( + isinstance(type_def, ObjectTypeExtensionNode) + and (type_def.directives or type_def.interfaces) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing " + f"empty GraphQL type definition" + ) + + return {field.name.value: field for field in type_def.fields} + + @classmethod + def __get_dependencies__(cls, type_def: ObjectNodeType) -> Dependencies: + return get_dependencies_from_object_type(type_def) + + @classmethod + def __bind_to_schema__(cls, schema): + graphql_type = schema.type_map.get(cls.graphql_name) + + for field_name, field_resolver in cls.resolvers.items(): + graphql_type.fields[field_name].resolve = field_resolver diff --git a/ariadne_graphql_modules/py.typed b/ariadne_graphql_modules/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/ariadne_graphql_modules/resolvers_mixin.py b/ariadne_graphql_modules/resolvers_mixin.py new file mode 100644 index 0000000..c91fee4 --- /dev/null +++ b/ariadne_graphql_modules/resolvers_mixin.py @@ -0,0 +1,102 @@ +from typing import Callable, Dict, Optional, Union + +from graphql import GraphQLFieldResolver + +from .types import FieldsDict +from .utils import create_alias_resolver + +Aliases = Dict[str, str] +FieldsArgs = Dict[str, Dict[str, str]] + + +class ResolversMixin: + """Adds aliases and resolvers logic to GraphQL type""" + + __aliases__: Optional[Union[Aliases, Callable[..., Aliases]]] = None + __fields_args__: Optional[Union[FieldsArgs, Callable[..., FieldsArgs]]] = None + + graphql_name: str + graphql_fields: FieldsDict + + resolvers: Dict[str, GraphQLFieldResolver] + + @classmethod + def __validate_aliases__(cls): + if not cls.__aliases__: + return + + invalid_aliases = set(cls.__aliases__) - set(cls.graphql_fields) + if invalid_aliases: + raise ValueError( + f"{cls.__name__} class was defined with aliases for fields not in " + f"GraphQL type: {', '.join(invalid_aliases)}" + ) + + @classmethod + def __validate_fields_args__(cls): + if not cls.__fields_args__: + return + + invalid_fields = set(cls.__fields_args__) - set(cls.graphql_fields) + if invalid_fields: + raise ValueError( + f"{cls.__name__} class was defined with fields args mappings " + f"for fields not in GraphQL type: {', '.join(invalid_fields)}" + ) + + for field_name, field_args in cls.__fields_args__.items(): + defined_args = [ + arg.name.value for arg in cls.graphql_fields[field_name].arguments + ] + invalid_args = set(field_args) - set(defined_args) + if invalid_args: + raise ValueError( + f"{cls.__name__} class was defined with args mappings not in " + f"not in '{field_name}' field: {', '.join(invalid_args)}" + ) + + @classmethod + def __get_resolvers__(cls): + aliases = cls.__aliases__ or {} + defined_resolvers = cls.__get_defined_resolvers__() + + used_resolvers = [] + resolvers = {} + + for field_name in cls.graphql_fields: + if aliases and field_name in aliases: + resolver_name = aliases[field_name] + if resolver_name in defined_resolvers: + used_resolvers.append(resolver_name) + resolvers[field_name] = defined_resolvers[resolver_name] + else: + resolvers[field_name] = create_alias_resolver(resolver_name) + + elif field_name in defined_resolvers: + used_resolvers.append(field_name) + resolvers[field_name] = defined_resolvers[field_name] + + unused_resolvers = [ + f"resolve_{field_name}" + for field_name in set(defined_resolvers) - set(used_resolvers) + ] + if unused_resolvers: + raise ValueError( + f"{cls.__name__} class was defined with resolvers for fields not in " + f"GraphQL type: {', '.join(unused_resolvers)}" + ) + + return resolvers + + @classmethod + def __get_defined_resolvers__(cls) -> Dict[str, Callable]: + resolvers = {} + for name in dir(cls): + if not name.startswith("resolve_"): + continue + + value = getattr(cls, name) + if callable(value): + resolvers[name[8:]] = value + + return resolvers diff --git a/ariadne_graphql_modules/scalar_type.py b/ariadne_graphql_modules/scalar_type.py new file mode 100644 index 0000000..8898acf --- /dev/null +++ b/ariadne_graphql_modules/scalar_type.py @@ -0,0 +1,90 @@ +from typing import Optional, Type, Union, cast + +from graphql import ( + DefinitionNode, + GraphQLScalarSerializer, + GraphQLScalarType, + GraphQLScalarLiteralParser, + GraphQLScalarValueParser, + GraphQLSchema, + ScalarTypeDefinitionNode, + ScalarTypeExtensionNode, +) + +from .base_type import BaseType +from .types import RequirementsDict +from .utils import parse_definition + +ScalarNodeType = Union[ScalarTypeDefinitionNode, ScalarTypeExtensionNode] + + +class ScalarType(BaseType): + __abstract__ = True + + graphql_type: Union[Type[ScalarTypeDefinitionNode], Type[ScalarTypeExtensionNode]] + + serialize: Optional[GraphQLScalarSerializer] = None + parse_value: Optional[GraphQLScalarValueParser] = None + parse_literal: Optional[GraphQLScalarLiteralParser] = None + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + + requirements = cls.__get_requirements__() + cls.__validate_requirements_contain_extended_type__(graphql_def, requirements) + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> ScalarNodeType: + if not isinstance( + type_def, (ScalarTypeDefinitionNode, ScalarTypeExtensionNode) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ " + "without GraphQL scalar" + ) + + return cast(ScalarNodeType, type_def) + + @classmethod + def __validate_requirements_contain_extended_type__( + cls, type_def: ScalarNodeType, requirements: RequirementsDict + ): + if not isinstance(type_def, ScalarTypeExtensionNode): + return + + graphql_name = type_def.name.value + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} graphql type was defined without required GraphQL " + f"scalar, definition for '{graphql_name}' in __requires__" + ) + + if requirements[graphql_name] != ScalarTypeDefinitionNode: + raise ValueError( + f"{cls.__name__} requires '{graphql_name}' to be GraphQL scalar " + f"but other type was provided in '__requires__'" + ) + + @classmethod + def __bind_to_schema__(cls, schema: GraphQLSchema): + graphql_type = cast(GraphQLScalarType, schema.type_map.get(cls.graphql_name)) + + # See mypy bug https://github.com/python/mypy/issues/2427 + if cls.serialize: + graphql_type.serialize = cls.serialize # type: ignore + if cls.parse_value: + graphql_type.parse_value = cls.parse_value # type: ignore + if cls.parse_literal: + graphql_type.parse_literal = cls.parse_literal # type: ignore diff --git a/ariadne_graphql_modules/subscription_type.py b/ariadne_graphql_modules/subscription_type.py new file mode 100644 index 0000000..4432859 --- /dev/null +++ b/ariadne_graphql_modules/subscription_type.py @@ -0,0 +1,100 @@ +from typing import Callable, Dict, Union, cast + +from graphql import ( + DefinitionNode, + GraphQLFieldResolver, + GraphQLObjectType, + GraphQLSchema, + ObjectTypeDefinitionNode, + ObjectTypeExtensionNode, +) + +from .object_type import ObjectType + +ObjectNodeType = Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode] + + +class SubscriptionType(ObjectType): + __abstract__ = True + + subscribers: Dict[str, GraphQLFieldResolver] + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.subscribers = cls.__get_subscribers__() + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> ObjectNodeType: + if not isinstance( + type_def, (ObjectTypeDefinitionNode, ObjectTypeExtensionNode) + ): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ without GraphQL type" + ) + + if type_def.name.value != "Subscription": + raise ValueError( + f"{cls.__name__} class was defined with __schema__ containing " + f"GraphQL definition for 'type {type_def.name.value}' " + "(expected 'type Subscription')" + ) + + return cast(ObjectNodeType, type_def) + + @classmethod + def __get_subscribers__(cls): + aliases = cls.__aliases__ or {} + defined_subscribers = cls.__get_defined_subscribers__() + + used_subscribers = [] + subscribers = {} + + for field_name in cls.graphql_fields: + if aliases and field_name in aliases: + subscription_name = aliases[field_name] + if subscription_name in defined_subscribers: + used_subscribers.append(subscription_name) + subscribers[field_name] = defined_subscribers[subscription_name] + + elif field_name in defined_subscribers: + used_subscribers.append(field_name) + subscribers[field_name] = defined_subscribers[field_name] + + unused_subscribers = [ + f"resolve_{field_name}" + for field_name in set(defined_subscribers) - set(used_subscribers) + ] + if unused_subscribers: + raise ValueError( + f"{cls.__name__} class was defined with subscribers for fields " + f"not in GraphQL type: {', '.join(unused_subscribers)}" + ) + + return subscribers + + @classmethod + def __get_defined_subscribers__(cls) -> Dict[str, Callable]: + resolvers = {} + for name in dir(cls): + if not name.startswith("subscribe_"): + continue + + value = getattr(cls, name) + if callable(value): + resolvers[name[10:]] = value + + return resolvers + + @classmethod + def __bind_to_schema__(cls, schema: GraphQLSchema): + graphql_type = cast(GraphQLObjectType, schema.type_map[cls.graphql_name]) + + for field_name, field_resolver in cls.resolvers.items(): + graphql_type.fields[field_name].resolve = field_resolver + + for field_name, field_subscriber in cls.subscribers.items(): + graphql_type.fields[field_name].subscribe = field_subscriber diff --git a/ariadne_graphql_modules/types.py b/ariadne_graphql_modules/types.py new file mode 100644 index 0000000..9c2173b --- /dev/null +++ b/ariadne_graphql_modules/types.py @@ -0,0 +1,11 @@ +from typing import Dict, Type + +from graphql import ( + DefinitionNode, + FieldDefinitionNode, + InputValueDefinitionNode, +) + +FieldsDict = Dict[str, FieldDefinitionNode] +InputFieldsDict = Dict[str, InputValueDefinitionNode] +RequirementsDict = Dict[str, Type[DefinitionNode]] diff --git a/ariadne_graphql_modules/union_type.py b/ariadne_graphql_modules/union_type.py new file mode 100644 index 0000000..e0a5e64 --- /dev/null +++ b/ariadne_graphql_modules/union_type.py @@ -0,0 +1,84 @@ +from typing import Type, Union, cast + +from graphql import ( + DefinitionNode, + GraphQLTypeResolver, + GraphQLSchema, + GraphQLUnionType, + UnionTypeDefinitionNode, + UnionTypeExtensionNode, +) + +from .base_type import BaseType +from .dependencies import Dependencies, get_dependencies_from_union_type +from .types import RequirementsDict +from .utils import parse_definition + +UnionNodeType = Union[UnionTypeDefinitionNode, UnionTypeExtensionNode] + + +class UnionType(BaseType): + __abstract__ = True + + graphql_type: Union[Type[UnionTypeDefinitionNode], Type[UnionTypeExtensionNode]] + resolve_type: GraphQLTypeResolver + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + + if cls.__dict__.get("__abstract__"): + return + + cls.__abstract__ = False + + graphql_def = cls.__validate_schema__( + parse_definition(cls.__name__, cls.__schema__) + ) + + cls.graphql_name = graphql_def.name.value + cls.graphql_type = type(graphql_def) + + requirements = cls.__get_requirements__() + cls.__validate_requirements_contain_extended_type__(graphql_def, requirements) + + dependencies = cls.__get_dependencies__(graphql_def) + cls.__validate_requirements__(requirements, dependencies) + + @classmethod + def __validate_schema__(cls, type_def: DefinitionNode) -> UnionNodeType: + if not isinstance(type_def, (UnionTypeDefinitionNode, UnionTypeExtensionNode)): + raise ValueError( + f"{cls.__name__} class was defined with __schema__ " + "without GraphQL union" + ) + + return cast(UnionNodeType, type_def) + + @classmethod + def __validate_requirements_contain_extended_type__( + cls, type_def: UnionNodeType, requirements: RequirementsDict + ): + if not isinstance(type_def, UnionTypeExtensionNode): + return + + graphql_name = type_def.name.value + if graphql_name not in requirements: + raise ValueError( + f"{cls.__name__} class was defined without required GraphQL union " + f"definition for '{graphql_name}' in __requires__" + ) + + if requirements[graphql_name] != UnionTypeDefinitionNode: + raise ValueError( + f"{cls.__name__} requires '{graphql_name}' to be GraphQL union " + f"but other type was provided in '__requires__'" + ) + + @classmethod + def __get_dependencies__(cls, type_def: UnionNodeType) -> Dependencies: + return get_dependencies_from_union_type(type_def) + + @classmethod + def __bind_to_schema__(cls, schema: GraphQLSchema): + graphql_type = cast(GraphQLUnionType, schema.type_map.get(cls.graphql_name)) + graphql_type.resolve_type = cls.resolve_type diff --git a/ariadne_graphql_modules/utils.py b/ariadne_graphql_modules/utils.py new file mode 100644 index 0000000..be29116 --- /dev/null +++ b/ariadne_graphql_modules/utils.py @@ -0,0 +1,52 @@ +from typing import Any, Mapping + +from graphql import ( + DefinitionNode, + GraphQLResolveInfo, + ListTypeNode, + NonNullTypeNode, + TypeNode, + parse, +) + + +def parse_definition(type_name: str, schema: Any) -> DefinitionNode: + if not isinstance(schema, str): + raise TypeError( + f"{type_name} class was defined with __schema__ of invalid type: " + f"{type(schema).__name__}" + ) + + definitions = parse(schema).definitions + + if len(definitions) > 1: + definitions_types = [type(definition).__name__ for definition in definitions] + raise ValueError( + f"{type_name} class was defined with __schema__ containing more " + f"than one GraphQL definition (found: {', '.join(definitions_types)})" + ) + + return definitions[0] + + +def unwrap_type_node(field_type: TypeNode): + if isinstance(field_type, (NonNullTypeNode, ListTypeNode)): + return unwrap_type_node(field_type.type) + return field_type + + +def create_alias_resolver(field_name: str): + def default_aliased_field_resolver( + source: Any, info: GraphQLResolveInfo, **args: Any + ) -> Any: + value = ( + source.get(field_name) + if isinstance(source, Mapping) + else getattr(source, field_name, None) + ) + + if callable(value): + return value(info, **args) + return value + + return default_aliased_field_resolver diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2289fa0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[tool.black] +line-length = 88 +target-version = ['py36', 'py37', 'py38'] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | snapshots +)/ +''' + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..5e7f61c --- /dev/null +++ b/setup.py @@ -0,0 +1,40 @@ +#! /usr/bin/env python +import os +from setuptools import setup + +CLASSIFIERS = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +README_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md") +with open(README_PATH, "r", encoding="utf8") as f: + README = f.read() + +setup( + name="ariadne-graphql-modules", + author="Mirumee Software", + author_email="hello@mirumee.com", + description="GraphQL Modules for Ariadne", + long_description=README, + long_description_content_type="text/markdown", + license="BSD", + version="0.1", + url="https://github.com/mirumee/ariadne-graphql-modules", + packages=["ariadne_graphql_modules"], + include_package_data=True, + install_requires=[ + "ariadne", + ], + classifiers=CLASSIFIERS, + platforms=["any"], + zip_safe=False, +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/snapshots/__init__.py b/tests/snapshots/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/snapshots/snap_test_definition_parser.py b/tests/snapshots/snap_test_definition_parser.py new file mode 100644 index 0000000..f78c1f4 --- /dev/null +++ b/tests/snapshots/snap_test_definition_parser.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_definition_parser_raises_error_schema_str_contains_multiple_types 1'] = GenericRepr("") + +snapshots['test_definition_parser_raises_error_when_schema_str_has_invalid_syntax 1'] = GenericRepr('') + +snapshots['test_definition_parser_raises_error_when_schema_type_is_invalid 1'] = GenericRepr("") diff --git a/tests/snapshots/snap_test_directive_type.py b/tests/snapshots/snap_test_directive_type.py new file mode 100644 index 0000000..4691c16 --- /dev/null +++ b/tests/snapshots/snap_test_directive_type.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_directive_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_directive_type_raises_attribute_error_when_defined_without_visitor 1'] = GenericRepr("") + +snapshots['test_directive_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_directive_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_directive_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_directive_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") diff --git a/tests/snapshots/snap_test_enum_type.py b/tests/snapshots/snap_test_enum_type.py new file mode 100644 index 0000000..0ee40db --- /dev/null +++ b/tests/snapshots/snap_test_enum_type.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_enum_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_enum_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_enum_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_enum_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_enum_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") + +snapshots['test_enum_type_raises_error_when_dict_mapping_has_extra_items_not_in_definition 1'] = GenericRepr("") + +snapshots['test_enum_type_raises_error_when_dict_mapping_misses_items_from_definition 1'] = GenericRepr("") + +snapshots['test_enum_type_raises_error_when_enum_mapping_has_extra_items_not_in_definition 1'] = GenericRepr("") + +snapshots['test_enum_type_raises_error_when_enum_mapping_misses_items_from_definition 1'] = GenericRepr("") diff --git a/tests/snapshots/snap_test_executable_schema.py b/tests/snapshots/snap_test_executable_schema.py new file mode 100644 index 0000000..afca833 --- /dev/null +++ b/tests/snapshots/snap_test_executable_schema.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_executable_schema_raises_value_error_if_merged_types_define_same_field 1'] = GenericRepr('') diff --git a/tests/snapshots/snap_test_input_type.py b/tests/snapshots/snap_test_input_type.py new file mode 100644 index 0000000..8df4264 --- /dev/null +++ b/tests/snapshots/snap_test_input_type.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_input_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_input_type_raises_error_when_defined_with_args_map_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_input_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_input_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_input_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_input_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") + +snapshots['test_input_type_raises_error_when_defined_without_extended_dependency 1'] = GenericRepr('') + +snapshots['test_input_type_raises_error_when_defined_without_field_type_dependency 1'] = GenericRepr('') + +snapshots['test_input_type_raises_error_when_defined_without_fields 1'] = GenericRepr("") + +snapshots['test_input_type_raises_error_when_extended_dependency_is_wrong_type 1'] = GenericRepr('') diff --git a/tests/snapshots/snap_test_interface_type.py b/tests/snapshots/snap_test_interface_type.py new file mode 100644 index 0000000..b21bd35 --- /dev/null +++ b/tests/snapshots/snap_test_interface_type.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_interface_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_interface_type_raises_error_when_defined_with_alias_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_interface_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_interface_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_interface_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_interface_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") + +snapshots['test_interface_type_raises_error_when_defined_with_resolver_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_interface_type_raises_error_when_defined_without_argument_type_dependency 1'] = GenericRepr('') + +snapshots['test_interface_type_raises_error_when_defined_without_extended_dependency 1'] = GenericRepr("") + +snapshots['test_interface_type_raises_error_when_defined_without_fields 1'] = GenericRepr("") + +snapshots['test_interface_type_raises_error_when_defined_without_return_type_dependency 1'] = GenericRepr('') + +snapshots['test_interface_type_raises_error_when_extended_dependency_is_wrong_type 1'] = GenericRepr('') diff --git a/tests/snapshots/snap_test_mutation_type.py b/tests/snapshots/snap_test_mutation_type.py new file mode 100644 index 0000000..f5d3262 --- /dev/null +++ b/tests/snapshots/snap_test_mutation_type.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_mutation_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_mutation_type_raises_error_when_defined_for_different_type_name 1'] = GenericRepr('') + +snapshots['test_mutation_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_mutation_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_mutation_type_raises_error_when_defined_with_multiple_fields 1'] = GenericRepr('') + +snapshots['test_mutation_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") + +snapshots['test_mutation_type_raises_error_when_defined_with_nonexistant_args 1'] = GenericRepr('') + +snapshots['test_mutation_type_raises_error_when_defined_without_callable_resolve_mutation_attr 1'] = GenericRepr('') + +snapshots['test_mutation_type_raises_error_when_defined_without_fields 1'] = GenericRepr("") + +snapshots['test_mutation_type_raises_error_when_defined_without_resolve_mutation_attr 1'] = GenericRepr('') + +snapshots['test_mutation_type_raises_error_when_defined_without_return_type_dependency 1'] = GenericRepr('') + +snapshots['test_object_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') diff --git a/tests/snapshots/snap_test_object_type.py b/tests/snapshots/snap_test_object_type.py new file mode 100644 index 0000000..d5ed96c --- /dev/null +++ b/tests/snapshots/snap_test_object_type.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_object_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_object_type_raises_error_when_defined_with_alias_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_object_type_raises_error_when_defined_with_field_args_for_nonexisting_arg 1'] = GenericRepr('') + +snapshots['test_object_type_raises_error_when_defined_with_field_args_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_object_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_object_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_object_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_object_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") + +snapshots['test_object_type_raises_error_when_defined_with_resolver_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_object_type_raises_error_when_defined_without_argument_type_dependency 1'] = GenericRepr('') + +snapshots['test_object_type_raises_error_when_defined_without_extended_dependency 1'] = GenericRepr('') + +snapshots['test_object_type_raises_error_when_defined_without_fields 1'] = GenericRepr("") + +snapshots['test_object_type_raises_error_when_defined_without_return_type_dependency 1'] = GenericRepr('') + +snapshots['test_object_type_raises_error_when_extended_dependency_is_wrong_type 1'] = GenericRepr('') diff --git a/tests/snapshots/snap_test_scalar_type.py b/tests/snapshots/snap_test_scalar_type.py new file mode 100644 index 0000000..8de8a0e --- /dev/null +++ b/tests/snapshots/snap_test_scalar_type.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_scalar_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_scalar_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_scalar_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_scalar_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_scalar_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") diff --git a/tests/snapshots/snap_test_subscription_type.py b/tests/snapshots/snap_test_subscription_type.py new file mode 100644 index 0000000..2b46bb1 --- /dev/null +++ b/tests/snapshots/snap_test_subscription_type.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_subscription_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_subscription_type_raises_error_when_defined_with_alias_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_subscription_type_raises_error_when_defined_with_invalid_graphql_type_name 1'] = GenericRepr('') + +snapshots['test_subscription_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_subscription_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_subscription_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_subscription_type_raises_error_when_defined_with_resolver_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_subscription_type_raises_error_when_defined_with_sub_for_nonexisting_field 1'] = GenericRepr("") + +snapshots['test_subscription_type_raises_error_when_defined_without_argument_type_dependency 1'] = GenericRepr('') + +snapshots['test_subscription_type_raises_error_when_defined_without_extended_dependency 1'] = GenericRepr('') + +snapshots['test_subscription_type_raises_error_when_defined_without_fields 1'] = GenericRepr("") + +snapshots['test_subscription_type_raises_error_when_defined_without_return_type_dependency 1'] = GenericRepr('') + +snapshots['test_subscription_type_raises_error_when_extended_dependency_is_wrong_type 1'] = GenericRepr('') diff --git a/tests/snapshots/snap_test_union_type.py b/tests/snapshots/snap_test_union_type.py new file mode 100644 index 0000000..26fd0c5 --- /dev/null +++ b/tests/snapshots/snap_test_union_type.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots['test_interface_type_raises_error_when_extended_dependency_is_wrong_type 1'] = GenericRepr('') + +snapshots['test_union_type_raises_attribute_error_when_defined_without_schema 1'] = GenericRepr('') + +snapshots['test_union_type_raises_error_when_defined_with_invalid_graphql_type_schema 1'] = GenericRepr("") + +snapshots['test_union_type_raises_error_when_defined_with_invalid_schema_str 1'] = GenericRepr('') + +snapshots['test_union_type_raises_error_when_defined_with_invalid_schema_type 1'] = GenericRepr("") + +snapshots['test_union_type_raises_error_when_defined_with_multiple_types_schema 1'] = GenericRepr("") + +snapshots['test_union_type_raises_error_when_defined_without_extended_dependency 1'] = GenericRepr('') + +snapshots['test_union_type_raises_error_when_defined_without_member_type_dependency 1'] = GenericRepr('') diff --git a/tests/test_convert_case.py b/tests/test_convert_case.py new file mode 100644 index 0000000..6c57ecd --- /dev/null +++ b/tests/test_convert_case.py @@ -0,0 +1,106 @@ +from ariadne_graphql_modules import InputType, ObjectType, convert_case + + +def test_cases_are_mapped_for_aliases(): + class ExampleObject(ObjectType): + __schema__ = """ + type Example { + field: Int + otherField: Int + } + """ + __aliases__ = convert_case + + assert ExampleObject.__aliases__ == {"otherField": "other_field"} + + +def test_cases_are_mapped_for_aliases_with_overrides(): + class ExampleObject(ObjectType): + __schema__ = """ + type Example { + field: Int + otherField: Int + } + """ + __aliases__ = convert_case({"field": "override"}) + + assert ExampleObject.__aliases__ == { + "field": "override", + "otherField": "other_field", + } + + +def test_cases_are_mapped_for_input_fields(): + class ExampleInput(InputType): + __schema__ = """ + input Example { + field: Int + otherField: Int + } + """ + __args__ = convert_case + + assert ExampleInput.__args__ == {"otherField": "other_field"} + + +def test_cases_are_mapped_for_input_fields_with_overrides(): + class ExampleInput(InputType): + __schema__ = """ + input Example { + field: Int + otherField: Int + } + """ + __args__ = convert_case({"field": "override"}) + + assert ExampleInput.__args__ == { + "field": "override", + "otherField": "other_field", + } + + +def test_cases_are_mapped_for_fields_args(): + class ExampleObject(ObjectType): + __schema__ = """ + type Example { + field(arg: Int, secondArg: Int): Int + otherField(arg: Int, thirdArg: Int): Int + } + """ + __fields_args__ = convert_case + + assert ExampleObject.__fields_args__ == { + "field": { + "secondArg": "second_arg", + }, + "otherField": { + "thirdArg": "third_arg", + }, + } + + +def test_cases_are_mapped_for_fields_args_with_overrides(): + class ExampleObject(ObjectType): + __schema__ = """ + type Example { + field(arg: Int, secondArg: Int): Int + otherField(arg: Int, thirdArg: Int): Int + } + """ + __fields_args__ = convert_case( + { + "field": { + "arg": "override", + } + } + ) + + assert ExampleObject.__fields_args__ == { + "field": { + "arg": "override", + "secondArg": "second_arg", + }, + "otherField": { + "thirdArg": "third_arg", + }, + } diff --git a/tests/test_definition_parser.py b/tests/test_definition_parser.py new file mode 100644 index 0000000..c5e49d3 --- /dev/null +++ b/tests/test_definition_parser.py @@ -0,0 +1,62 @@ +import pytest +from graphql import GraphQLError +from graphql.language.ast import ObjectTypeDefinitionNode + +from ariadne_graphql_modules import parse_definition + + +def test_definition_parser_returns_definition_type_from_valid_schema_string(): + type_def = parse_definition( + "MyType", + """ + type My { + id: ID! + } + """, + ) + + assert isinstance(type_def, ObjectTypeDefinitionNode) + assert type_def.name.value == "My" + assert type_def.fields[0].name.value == "id" + + +def test_definition_parser_parses_definition_with_description(): + type_def = parse_definition( + "MyType", + """ + "Test user type" + type User + """, + ) + + assert isinstance(type_def, ObjectTypeDefinitionNode) + assert type_def.name.value == "User" + assert type_def.description.value == "Test user type" + + +def test_definition_parser_raises_error_when_schema_type_is_invalid(snapshot): + with pytest.raises(TypeError) as err: + parse_definition("MyType", True) + + snapshot.assert_match(err) + + +def test_definition_parser_raises_error_when_schema_str_has_invalid_syntax(snapshot): + with pytest.raises(GraphQLError) as err: + parse_definition("MyType", "typo User") + + snapshot.assert_match(err) + + +def test_definition_parser_raises_error_schema_str_contains_multiple_types(snapshot): + with pytest.raises(ValueError) as err: + parse_definition( + "MyType", + """ + type User + + type Group + """, + ) + + snapshot.assert_match(err) diff --git a/tests/test_directive_type.py b/tests/test_directive_type.py new file mode 100644 index 0000000..fcd3846 --- /dev/null +++ b/tests/test_directive_type.py @@ -0,0 +1,105 @@ +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, default_field_resolver, graphql_sync + +from ariadne_graphql_modules import DirectiveType, ObjectType, make_executable_schema + + +def test_directive_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + pass + + snapshot.assert_match(err) + + +def test_directive_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_directive_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directivo @example on FIELD_DEFINITION" + + snapshot.assert_match(err) + + +def test_directive_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "scalar example" + + snapshot.assert_match(err) + + +def test_directive_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = """ + directive @example on FIELD_DEFINITION + + directive @other on OBJECT + """ + + snapshot.assert_match(err) + + +def test_directive_type_raises_attribute_error_when_defined_without_visitor(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on FIELD_DEFINITION" + + snapshot.assert_match(err) + + +class ExampleSchemaVisitor(SchemaDirectiveVisitor): + def visit_field_definition(self, field, object_type): + original_resolver = field.resolve or default_field_resolver + + def resolve_prefixed_value(obj, info, **kwargs): + result = original_resolver(obj, info, **kwargs) + if result: + return f"PREFIX: {result}" + return result + + field.resolve = resolve_prefixed_value + return field + + +def test_directive_type_extracts_graphql_name(): + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on FIELD_DEFINITION" + __visitor__ = ExampleSchemaVisitor + + assert ExampleDirective.graphql_name == "example" + + +def test_directive_is_set_on_field(): + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on FIELD_DEFINITION" + __visitor__ = ExampleSchemaVisitor + + class QueryType(ObjectType): + __schema__ = """ + type Query { + field: String! @example + } + """ + __requires__ = [ExampleDirective] + + schema = make_executable_schema(QueryType) + result = graphql_sync(schema, "{ field }", root_value={"field": "test"}) + assert result.data == {"field": "PREFIX: test"} diff --git a/tests/test_enum_type.py b/tests/test_enum_type.py new file mode 100644 index 0000000..730eb8c --- /dev/null +++ b/tests/test_enum_type.py @@ -0,0 +1,302 @@ +from enum import Enum + +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, graphql_sync + +from ariadne_graphql_modules import ( + DirectiveType, + EnumType, + ObjectType, + make_executable_schema, +) + + +def test_enum_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + pass + + snapshot.assert_match(err) + + +def test_enum_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_enum_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = "enom UserRole" + + snapshot.assert_match(err) + + +def test_enum_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = "scalar UserRole" + + snapshot.assert_match(err) + + +def test_enum_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + + enum Category { + CATEGORY + LINK + } + """ + + snapshot.assert_match(err) + + +def test_enum_type_extracts_graphql_name(): + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + + assert UserRoleEnum.graphql_name == "UserRole" + + +def test_enum_type_can_be_extended_with_new_values(): + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + + class ExtendUserRoleEnum(EnumType): + __schema__ = """ + extend enum UserRole { + MVP + } + """ + __requires__ = [UserRoleEnum] + + +def test_enum_type_can_be_extended_with_directive(): + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on ENUM" + __visitor__ = SchemaDirectiveVisitor + + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + + class ExtendUserRoleEnum(EnumType): + __schema__ = "extend enum UserRole @example" + __requires__ = [UserRoleEnum, ExampleDirective] + + +class BaseQueryType(ObjectType): + __abstract__ = True + __schema__ = """ + type Query { + enumToRepr(enum: UserRole = USER): String! + reprToEnum: UserRole! + } + """ + __aliases__ = { + "enumToRepr": "enum_repr", + } + + @staticmethod + def resolve_enum_repr(*_, enum) -> str: + return repr(enum) + + +def make_test_schema(enum_type): + class QueryType(BaseQueryType): + __requires__ = [enum_type] + + return make_executable_schema(QueryType) + + +def test_enum_type_can_be_defined_with_dict_mapping(): + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + __enum__ = { + "USER": 0, + "MOD": 1, + "ADMIN": 2, + } + + schema = make_test_schema(UserRoleEnum) + + # Specfied enum value is reversed + result = graphql_sync(schema, "{ enumToRepr(enum: MOD) }") + assert result.data["enumToRepr"] == "1" + + # Default enum value is reversed + result = graphql_sync(schema, "{ enumToRepr }") + assert result.data["enumToRepr"] == "0" + + # Python value is converted to enum + result = graphql_sync(schema, "{ reprToEnum }", root_value={"reprToEnum": 2}) + assert result.data["reprToEnum"] == "ADMIN" + + +def test_enum_type_raises_error_when_dict_mapping_misses_items_from_definition( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + __enum__ = { + "USER": 0, + "MODERATOR": 1, + "ADMIN": 2, + } + + snapshot.assert_match(err) + + +def test_enum_type_raises_error_when_dict_mapping_has_extra_items_not_in_definition( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + __enum__ = { + "USER": 0, + "REVIEW": 1, + "MOD": 2, + "ADMIN": 3, + } + + snapshot.assert_match(err) + + +def test_enum_type_can_be_defined_with_str_enum_mapping(): + class RoleEnum(str, Enum): + USER = "user" + MOD = "moderator" + ADMIN = "administrator" + + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + __enum__ = RoleEnum + + schema = make_test_schema(UserRoleEnum) + + # Specfied enum value is reversed + result = graphql_sync(schema, "{ enumToRepr(enum: MOD) }") + assert result.data["enumToRepr"] == repr(RoleEnum.MOD) + + # Default enum value is reversed + result = graphql_sync(schema, "{ enumToRepr }") + assert result.data["enumToRepr"] == repr(RoleEnum.USER) + + # Python value is converted to enum + result = graphql_sync( + schema, "{ reprToEnum }", root_value={"reprToEnum": "administrator"} + ) + assert result.data["reprToEnum"] == "ADMIN" + + +def test_enum_type_raises_error_when_enum_mapping_misses_items_from_definition( + snapshot, +): + class RoleEnum(str, Enum): + USER = "user" + MODERATOR = "moderator" + ADMIN = "administrator" + + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + __enum__ = RoleEnum + + snapshot.assert_match(err) + + +def test_enum_type_raises_error_when_enum_mapping_has_extra_items_not_in_definition( + snapshot, +): + class RoleEnum(str, Enum): + USER = "user" + REVIEW = "review" + MOD = "moderator" + ADMIN = "administrator" + + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserRoleEnum(EnumType): + __schema__ = """ + enum UserRole { + USER + MOD + ADMIN + } + """ + __enum__ = RoleEnum + + snapshot.assert_match(err) diff --git a/tests/test_executable_schema.py b/tests/test_executable_schema.py new file mode 100644 index 0000000..f5dfe13 --- /dev/null +++ b/tests/test_executable_schema.py @@ -0,0 +1,90 @@ +import pytest +from graphql import graphql_sync + +from ariadne_graphql_modules import ObjectType, make_executable_schema + + +def test_executable_schema_is_created_from_object_types(): + class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + username: String! + } + """ + __aliases__ = { + "username": "user_name", + } + + class QueryType(ObjectType): + __schema__ = """ + type Query { + user: User + } + """ + __requires__ = [UserType] + + @staticmethod + def resolve_user(*_): + return { + "id": 1, + "user_name": "Alice", + } + + schema = make_executable_schema(QueryType) + result = graphql_sync(schema, "{ user { id username } }") + assert result.errors is None + assert result.data == {"user": {"id": "1", "username": "Alice"}} + + +def test_executable_schema_merges_root_types(): + class CityQueryType(ObjectType): + __schema__ = """ + type Query { + city: String! + } + """ + + @staticmethod + def resolve_city(*_): + return "Wroclaw" + + class YearQueryType(ObjectType): + __schema__ = """ + type Query { + year: Int! + } + """ + + @staticmethod + def resolve_year(*_): + return 2022 + + schema = make_executable_schema(CityQueryType, YearQueryType) + result = graphql_sync(schema, "{ city, year }") + assert result.errors is None + assert result.data == {"city": "Wroclaw", "year": 2022} + + +def test_executable_schema_raises_value_error_if_merged_types_define_same_field( + snapshot, +): + class CityQueryType(ObjectType): + __schema__ = """ + type Query { + city: String + } + """ + + class YearQueryType(ObjectType): + __schema__ = """ + type Query { + city: String + year: Int + } + """ + + with pytest.raises(ValueError) as err: + make_executable_schema(CityQueryType, YearQueryType) + + snapshot.assert_match(err) diff --git a/tests/test_input_type.py b/tests/test_input_type.py new file mode 100644 index 0000000..ded8d94 --- /dev/null +++ b/tests/test_input_type.py @@ -0,0 +1,285 @@ +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, graphql_sync + +from ariadne_graphql_modules import ( + DeferredType, + DirectiveType, + EnumType, + InputType, + InterfaceType, + ObjectType, + ScalarType, + make_executable_schema, +) + + +def test_input_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + pass + + snapshot.assert_match(err) + + +def test_input_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_input_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = "inpet UserInput" + + snapshot.assert_match(err) + + +def test_input_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = """ + type User { + id: ID! + } + """ + + snapshot.assert_match(err) + + +def test_input_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = """ + input User + + input Group + """ + + snapshot.assert_match(err) + + +def test_input_type_raises_error_when_defined_without_fields(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = "input User" + + snapshot.assert_match(err) + + +def test_input_type_extracts_graphql_name(): + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + } + """ + + assert UserInput.graphql_name == "User" + + +def test_input_type_raises_error_when_defined_without_field_type_dependency(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + role: Role! + } + """ + + snapshot.assert_match(err) + + +def test_input_type_verifies_field_dependency(): + # pylint: disable=unused-variable + class RoleEnum(EnumType): + __schema__ = """ + enum Role { + USER + ADMIN + } + """ + + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + role: Role! + } + """ + __requires__ = [RoleEnum] + + +def test_input_type_verifies_circular_dependency(): + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + patron: User + } + """ + + +def test_input_type_verifies_circular_dependency_using_deferred_type(): + # pylint: disable=unused-variable + class GroupInput(InputType): + __schema__ = """ + input Group { + id: ID! + patron: User + } + """ + __requires__ = [DeferredType("User")] + + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + group: Group + } + """ + __requires__ = [GroupInput] + + +def test_input_type_can_be_extended_with_new_fields(): + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + } + """ + + class ExtendUserInput(InputType): + __schema__ = """ + extend input User { + name: String! + } + """ + __requires__ = [UserInput] + + +def test_input_type_can_be_extended_with_directive(): + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on INPUT_OBJECT" + __visitor__ = SchemaDirectiveVisitor + + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + } + """ + + class ExtendUserInput(InputType): + __schema__ = """ + extend input User @example + """ + __requires__ = [UserInput, ExampleDirective] + + +def test_input_type_raises_error_when_defined_without_extended_dependency(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExtendUserInput(InputType): + __schema__ = """ + extend input User { + name: String! + } + """ + + snapshot.assert_match(err) + + +def test_input_type_raises_error_when_extended_dependency_is_wrong_type(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface User { + id: ID! + } + """ + + class ExtendUserInput(InputType): + __schema__ = """ + extend input User { + name: String! + } + """ + __requires__ = [ExampleInterface] + + snapshot.assert_match(err) + + +def test_input_type_raises_error_when_defined_with_args_map_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserInput(InputType): + __schema__ = """ + input User { + id: ID! + } + """ + __args__ = { + "fullName": "full_name", + } + + snapshot.assert_match(err) + + +class UserInput(InputType): + __schema__ = """ + input UserInput { + id: ID! + fullName: String! + } + """ + __args__ = { + "fullName": "full_name", + } + + +class GenericScalar(ScalarType): + __schema__ = "scalar Generic" + + +class QueryType(ObjectType): + __schema__ = """ + type Query { + reprInput(input: UserInput): Generic! + } + """ + __aliases__ = {"reprInput": "repr_input"} + __requires__ = [GenericScalar, UserInput] + + @staticmethod + def resolve_repr_input(*_, input): # pylint: disable=redefined-builtin + return input + + +schema = make_executable_schema(QueryType) + + +def test_input_type_maps_args_to_python_dict_keys(): + result = graphql_sync(schema, '{ reprInput(input: {id: "1", fullName: "Alice"}) }') + assert result.data == { + "reprInput": {"id": "1", "full_name": "Alice"}, + } diff --git a/tests/test_interface_type.py b/tests/test_interface_type.py new file mode 100644 index 0000000..4a064f4 --- /dev/null +++ b/tests/test_interface_type.py @@ -0,0 +1,450 @@ +from dataclasses import dataclass + +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, graphql_sync + +from ariadne_graphql_modules import ( + DeferredType, + DirectiveType, + InterfaceType, + ObjectType, + make_executable_schema, +) + + +def test_interface_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + pass + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = "interfaco Example" + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = "type Example" + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example + + interface Other + """ + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_defined_without_fields(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = "interface Example" + + snapshot.assert_match(err) + + +def test_interface_type_extracts_graphql_name(): + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + id: ID! + } + """ + + assert ExampleInterface.graphql_name == "Example" + + +def test_interface_type_raises_error_when_defined_without_return_type_dependency( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + group: Group + groups: [Group!] + } + """ + + snapshot.assert_match(err) + + +def test_interface_type_verifies_field_dependency(): + # pylint: disable=unused-variable + class GroupType(ObjectType): + __schema__ = """ + type Group { + id: ID! + } + """ + + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + group: Group + groups: [Group!] + } + """ + __requires__ = [GroupType] + + +def test_interface_type_verifies_circural_dependency(): + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + parent: Example + } + """ + + +def test_interface_type_raises_error_when_defined_without_argument_type_dependency( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + actions(input: UserInput): [String!]! + } + """ + + snapshot.assert_match(err) + + +def test_interface_type_verifies_circular_dependency_using_deferred_type(): + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + id: ID! + users: [User] + } + """ + __requires__ = [DeferredType("User")] + + class UserType(ObjectType): + __schema__ = """ + type User { + roles: [Example] + } + """ + __requires__ = [ExampleInterface] + + +def test_interface_type_can_be_extended_with_new_fields(): + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + id: ID! + } + """ + + class ExtendExampleInterface(InterfaceType): + __schema__ = """ + extend interface Example { + name: String + } + """ + __requires__ = [ExampleInterface] + + +def test_interface_type_can_be_extended_with_directive(): + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on INTERFACE" + __visitor__ = SchemaDirectiveVisitor + + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + id: ID! + } + """ + + class ExtendExampleInterface(InterfaceType): + __schema__ = """ + extend interface Example @example + """ + __requires__ = [ExampleInterface, ExampleDirective] + + +def test_interface_type_can_be_extended_with_other_interface(): + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + id: ID! + } + """ + + class OtherInterface(InterfaceType): + __schema__ = """ + interface Other { + depth: Int! + } + """ + + class ExtendExampleInterface(InterfaceType): + __schema__ = """ + extend interface Example implements Other + """ + __requires__ = [ExampleInterface, OtherInterface] + + +def test_interface_type_raises_error_when_defined_without_extended_dependency(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExtendExampleInterface(ObjectType): + __schema__ = """ + extend interface Example { + name: String + } + """ + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_extended_dependency_is_wrong_type(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleType(ObjectType): + __schema__ = """ + type Example { + id: ID! + } + """ + + class ExampleInterface(InterfaceType): + __schema__ = """ + extend interface Example { + name: String + } + """ + __requires__ = [ExampleType] + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_defined_with_alias_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface User { + name: String + } + """ + __aliases__ = { + "joinedDate": "joined_date", + } + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_defined_with_resolver_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface User { + name: String + } + """ + + @staticmethod + def resolve_group(*_): + return None + + snapshot.assert_match(err) + + +@dataclass +class User: + id: int + name: str + summary: str + + +@dataclass +class Comment: + id: int + message: str + summary: str + + +class ResultInterface(InterfaceType): + __schema__ = """ + interface Result { + summary: String! + score: Int! + } + """ + + @staticmethod + def resolve_type(instance, *_): + if isinstance(instance, Comment): + return "Comment" + + if isinstance(instance, User): + return "User" + + return None + + @staticmethod + def resolve_score(*_): + return 42 + + +class UserType(ObjectType): + __schema__ = """ + type User implements Result { + id: ID! + name: String! + summary: String! + score: Int! + } + """ + __requires__ = [ResultInterface] + + +class CommentType(ObjectType): + __schema__ = """ + type Comment implements Result { + id: ID! + message: String! + summary: String! + score: Int! + } + """ + __requires__ = [ResultInterface] + + @staticmethod + def resolve_score(*_): + return 16 + + +class QueryType(ObjectType): + __schema__ = """ + type Query { + results: [Result!]! + } + """ + __requires__ = [ResultInterface] + + @staticmethod + def resolve_results(*_): + return [ + User(id=1, name="Alice", summary="Summary for Alice"), + Comment(id=1, message="Hello world!", summary="Summary for comment"), + ] + + +schema = make_executable_schema(QueryType, UserType, CommentType) + + +def test_interface_type_binds_type_resolver(): + query = """ + query { + results { + ... on User { + __typename + id + name + summary + } + ... on Comment { + __typename + id + message + summary + } + } + } + """ + + result = graphql_sync(schema, query) + assert result.data == { + "results": [ + { + "__typename": "User", + "id": "1", + "name": "Alice", + "summary": "Summary for Alice", + }, + { + "__typename": "Comment", + "id": "1", + "message": "Hello world!", + "summary": "Summary for comment", + }, + ], + } + + +def test_interface_type_binds_field_resolvers_to_implementing_types_fields(): + query = """ + query { + results { + ... on User { + __typename + score + } + ... on Comment { + __typename + score + } + } + } + """ + + result = graphql_sync(schema, query) + assert result.data == { + "results": [ + { + "__typename": "User", + "score": 42, + }, + { + "__typename": "Comment", + "score": 16, + }, + ], + } diff --git a/tests/test_mutation_type.py b/tests/test_mutation_type.py new file mode 100644 index 0000000..2b73477 --- /dev/null +++ b/tests/test_mutation_type.py @@ -0,0 +1,318 @@ +import pytest +from graphql import GraphQLError, graphql_sync + +from ariadne_graphql_modules import ( + MutationType, + ObjectType, + make_executable_schema, +) + + +def test_mutation_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + pass + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = "typo User" + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = "scalar DateTime" + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type User + + type Group + """ + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_for_different_type_name(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type User { + id: ID! + } + """ + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_without_fields(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type Mutation + """ + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_with_multiple_fields(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type Mutation { + userCreate(name: String!): Boolean! + userUpdate(id: ID!, name: String!): Boolean! + } + """ + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_without_resolve_mutation_attr( + snapshot, +): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type Mutation { + userCreate(name: String!): Boolean! + } + """ + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_without_callable_resolve_mutation_attr( + snapshot, +): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type Mutation { + userCreate(name: String!): Boolean! + } + """ + + resolve_mutation = True + + snapshot.assert_match(err) + + +def test_mutation_type_raises_error_when_defined_without_return_type_dependency( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type Mutation { + userCreate(name: String!): UserCreateResult! + } + """ + + @staticmethod + def resolve_mutation(*_args): + pass + + snapshot.assert_match(err) + + +def test_mutation_type_verifies_field_dependency(): + # pylint: disable=unused-variable + class UserCreateResult(ObjectType): + __schema__ = """ + type UserCreateResult { + errors: [String!] + } + """ + + class UserCreateMutation(MutationType): + __schema__ = """ + type Mutation { + userCreate(name: String!): UserCreateResult! + } + """ + __requires__ = [UserCreateResult] + + @staticmethod + def resolve_mutation(*_args): + pass + + +def test_mutation_type_raises_error_when_defined_with_nonexistant_args( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserCreateMutation(MutationType): + __schema__ = """ + type Mutation { + userCreate(name: String!): Boolean! + } + """ + __args__ = {"realName": "real_name"} + + @staticmethod + def resolve_mutation(*_args): + pass + + snapshot.assert_match(err) + + +class QueryType(ObjectType): + __schema__ = """ + type Query { + field: String! + } + """ + + +class ResultType(ObjectType): + __schema__ = """ + type Result { + error: String + total: Int + } + """ + + +class SumMutation(MutationType): + __schema__ = """ + type Mutation { + sum(a: Int!, b: Int!): Result! + } + """ + __requires__ = [ResultType] + + @staticmethod + def resolve_mutation(*_, a: int, b: int): + return {"total": a + b} + + +class DivideMutation(MutationType): + __schema__ = """ + type Mutation { + divide(a: Int!, b: Int!): Result! + } + """ + __requires__ = [ResultType] + + @staticmethod + def resolve_mutation(*_, a: int, b: int): + if a == 0 or b == 0: + return {"error": "Division by zero"} + + return {"total": a / b} + + +class SplitMutation(MutationType): + __schema__ = """ + type Mutation { + split(strToSplit: String!): [String!]! + } + """ + __args__ = {"strToSplit": "split_str"} + + @staticmethod + def resolve_mutation(*_, split_str: str): + return split_str.split() + + +schema = make_executable_schema(QueryType, SumMutation, DivideMutation, SplitMutation) + + +def test_sum_mutation_resolves_to_result(): + query = """ + mutation { + sum(a: 5, b: 3) { + total + error + } + } + """ + result = graphql_sync(schema, query) + assert result.data == { + "sum": { + "total": 8, + "error": None, + }, + } + + +def test_divide_mutation_resolves_to_result(): + query = """ + mutation { + divide(a: 6, b: 3) { + total + error + } + } + """ + result = graphql_sync(schema, query) + assert result.data == { + "divide": { + "total": 2, + "error": None, + }, + } + + +def test_divide_mutation_resolves_to_error_result(): + query = """ + mutation { + divide(a: 6, b: 0) { + total + error + } + } + """ + result = graphql_sync(schema, query) + assert result.data == { + "divide": { + "total": None, + "error": "Division by zero", + }, + } + + +def test_split_mutation_uses_arg_mapping(): + query = """ + mutation { + split(strToSplit: "a b c") + } + """ + result = graphql_sync(schema, query) + assert result.data == { + "split": ["a", "b", "c"], + } diff --git a/tests/test_object_type.py b/tests/test_object_type.py new file mode 100644 index 0000000..93f9327 --- /dev/null +++ b/tests/test_object_type.py @@ -0,0 +1,372 @@ +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, graphql_sync + +from ariadne_graphql_modules import ( + DeferredType, + DirectiveType, + InterfaceType, + ObjectType, + make_executable_schema, +) + + +def test_object_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + pass + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = "typo User" + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = "scalar DateTime" + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User + + type Group + """ + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_without_fields(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = "type User" + + snapshot.assert_match(err) + + +def test_object_type_extracts_graphql_name(): + class GroupType(ObjectType): + __schema__ = """ + type Group { + id: ID! + } + """ + + assert GroupType.graphql_name == "Group" + + +def test_object_type_raises_error_when_defined_without_return_type_dependency(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + group: Group + groups: [Group!] + } + """ + + snapshot.assert_match(err) + + +def test_object_type_verifies_field_dependency(): + # pylint: disable=unused-variable + class GroupType(ObjectType): + __schema__ = """ + type Group { + id: ID! + } + """ + + class UserType(ObjectType): + __schema__ = """ + type User { + group: Group + groups: [Group!] + } + """ + __requires__ = [GroupType] + + +def test_object_type_verifies_circular_dependency(): + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + follows: User + } + """ + + +def test_object_type_raises_error_when_defined_without_argument_type_dependency( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + actions(input: UserInput): [String!]! + } + """ + + snapshot.assert_match(err) + + +def test_object_type_verifies_circular_dependency_using_deferred_type(): + # pylint: disable=unused-variable + class GroupType(ObjectType): + __schema__ = """ + type Group { + id: ID! + users: [User] + } + """ + __requires__ = [DeferredType("User")] + + class UserType(ObjectType): + __schema__ = """ + type User { + group: Group + } + """ + __requires__ = [GroupType] + + +def test_object_type_can_be_extended_with_new_fields(): + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + } + """ + + class ExtendUserType(ObjectType): + __schema__ = """ + extend type User { + name: String + } + """ + __requires__ = [UserType] + + +def test_object_type_can_be_extended_with_directive(): + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on OBJECT" + __visitor__ = SchemaDirectiveVisitor + + class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + } + """ + + class ExtendUserType(ObjectType): + __schema__ = """ + extend type User @example + """ + __requires__ = [UserType, ExampleDirective] + + +def test_object_type_can_be_extended_with_interface(): + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Interface { + id: ID! + } + """ + + class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + } + """ + + class ExtendUserType(ObjectType): + __schema__ = """ + extend type User implements Interface + """ + __requires__ = [UserType, ExampleInterface] + + +def test_object_type_raises_error_when_defined_without_extended_dependency(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExtendUserType(ObjectType): + __schema__ = """ + extend type User { + name: String + } + """ + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_extended_dependency_is_wrong_type(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Example { + id: ID! + } + """ + + class ExampleType(ObjectType): + __schema__ = """ + extend type Example { + name: String + } + """ + __requires__ = [ExampleInterface] + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_alias_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + name: String + } + """ + __aliases__ = { + "joinedDate": "joined_date", + } + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_resolver_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + name: String + } + """ + + @staticmethod + def resolve_group(*_): + return None + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_field_args_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + name: String + } + """ + __fields_args__ = {"group": {}} + + snapshot.assert_match(err) + + +def test_object_type_raises_error_when_defined_with_field_args_for_nonexisting_arg( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UserType(ObjectType): + __schema__ = """ + type User { + name: String + } + """ + __fields_args__ = {"name": {"arg": "arg2"}} + + snapshot.assert_match(err) + + +class QueryType(ObjectType): + __schema__ = """ + type Query { + field: String! + other: String! + firstField: String! + secondField: String! + } + """ + __aliases__ = { + "firstField": "first_field", + "secondField": "second_field", + } + + @staticmethod + def resolve_other(*_): + return "Word Up!" + + @staticmethod + def resolve_second_field(obj, *_): + return "Obj: %s" % obj["secondField"] + + +schema = make_executable_schema(QueryType) + + +def test_object_resolves_field_with_default_resolver(): + result = graphql_sync(schema, "{ field }", root_value={"field": "Hello!"}) + assert result.data["field"] == "Hello!" + + +def test_object_resolves_field_with_custom_resolver(): + result = graphql_sync(schema, "{ other }") + assert result.data["other"] == "Word Up!" + + +def test_object_resolves_field_with_aliased_default_resolver(): + result = graphql_sync( + schema, "{ firstField }", root_value={"first_field": "Howdy?"} + ) + assert result.data["firstField"] == "Howdy?" + + +def test_object_resolves_field_with_aliased_custom_resolver(): + result = graphql_sync(schema, "{ secondField }", root_value={"secondField": "Hey!"}) + assert result.data["secondField"] == "Obj: Hey!" diff --git a/tests/test_scalar_type.py b/tests/test_scalar_type.py new file mode 100644 index 0000000..ebdc18d --- /dev/null +++ b/tests/test_scalar_type.py @@ -0,0 +1,281 @@ +from datetime import date, datetime + +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, StringValueNode, graphql_sync + +from ariadne_graphql_modules import ( + DirectiveType, + ObjectType, + ScalarType, + make_executable_schema, +) + + +def test_scalar_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class DateScalar(ScalarType): + pass + + snapshot.assert_match(err) + + +def test_scalar_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class DateScalar(ScalarType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_scalar_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class DateScalar(ScalarType): + __schema__ = "scalor Date" + + snapshot.assert_match(err) + + +def test_scalar_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class DateScalar(ScalarType): + __schema__ = "type DateTime" + + snapshot.assert_match(err) + + +def test_scalar_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class DateScalar(ScalarType): + __schema__ = """ + scalar Date + + scalar DateTime + """ + + snapshot.assert_match(err) + + +def test_scalar_type_extracts_graphql_name(): + class DateScalar(ScalarType): + __schema__ = "scalar Date" + + assert DateScalar.graphql_name == "Date" + + +def test_scalar_type_can_be_extended_with_directive(): + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on SCALAR" + __visitor__ = SchemaDirectiveVisitor + + class DateScalar(ScalarType): + __schema__ = "scalar Date" + + class ExtendDateScalar(ScalarType): + __schema__ = "extend scalar Date @example" + __requires__ = [DateScalar, ExampleDirective] + + +class DateReadOnlyScalar(ScalarType): + __schema__ = "scalar DateReadOnly" + + @staticmethod + def serialize(date): + return date.strftime("%Y-%m-%d") + + +class DateInputScalar(ScalarType): + __schema__ = "scalar DateInput" + + @staticmethod + def parse_value(formatted_date): + parsed_datetime = datetime.strptime(formatted_date, "%Y-%m-%d") + return parsed_datetime.date() + + @staticmethod + def parse_literal(ast, variable_values=None): # pylint: disable=unused-argument + if not isinstance(ast, StringValueNode): + raise ValueError() + + formatted_date = ast.value + parsed_datetime = datetime.strptime(formatted_date, "%Y-%m-%d") + return parsed_datetime.date() + + +class DefaultParserScalar(ScalarType): + __schema__ = "scalar DefaultParser" + + @staticmethod + def parse_value(value): + return type(value).__name__ + + +TEST_DATE = date(2006, 9, 13) +TEST_DATE_SERIALIZED = TEST_DATE.strftime("%Y-%m-%d") + + +class QueryType(ObjectType): + __schema__ = """ + type Query { + testSerialize: DateReadOnly! + testInput(value: DateInput!): Boolean! + testInputValueType(value: DefaultParser!): String! + } + """ + __requires__ = [ + DateReadOnlyScalar, + DateInputScalar, + DefaultParserScalar, + ] + __aliases__ = { + "testSerialize": "test_serialize", + "testInput": "test_input", + "testInputValueType": "test_input_value_type", + } + + @staticmethod + def resolve_test_serialize(*_): + return TEST_DATE + + @staticmethod + def resolve_test_input(*_, value): + assert value == TEST_DATE + return True + + @staticmethod + def resolve_test_input_value_type(*_, value): + return value + + +schema = make_executable_schema(QueryType) + + +def test_attempt_deserialize_str_literal_without_valid_date_raises_error(): + test_input = "invalid string" + result = graphql_sync(schema, '{ testInput(value: "%s") }' % test_input) + assert result.errors is not None + assert str(result.errors[0]).splitlines()[:1] == [ + "Expected value of type 'DateInput!', found \"invalid string\"; " + "time data 'invalid string' does not match format '%Y-%m-%d'" + ] + + +def test_attempt_deserialize_wrong_type_literal_raises_error(): + test_input = 123 + result = graphql_sync(schema, "{ testInput(value: %s) }" % test_input) + assert result.errors is not None + assert str(result.errors[0]).splitlines()[:1] == [ + "Expected value of type 'DateInput!', found 123; " + ] + + +def test_default_literal_parser_is_used_to_extract_value_str_from_ast_node(): + class ValueParserOnlyScalar(ScalarType): + __schema__ = "scalar DateInput" + + @staticmethod + def parse_value(formatted_date): + parsed_datetime = datetime.strptime(formatted_date, "%Y-%m-%d") + return parsed_datetime.date() + + class ValueParserOnlyQueryType(ObjectType): + __schema__ = """ + type Query { + parse(value: DateInput!): String! + } + """ + __requires__ = [ValueParserOnlyScalar] + + @staticmethod + def resolve_parse(*_, value): + return value + + schema = make_executable_schema(ValueParserOnlyQueryType) + result = graphql_sync(schema, """{ parse(value: "%s") }""" % TEST_DATE_SERIALIZED) + assert result.errors is None + assert result.data == {"parse": "2006-09-13"} + + +parametrized_query = """ + query parseValueTest($value: DateInput!) { + testInput(value: $value) + } +""" + + +def test_variable_with_valid_date_string_is_deserialized_to_python_date(): + variables = {"value": TEST_DATE_SERIALIZED} + result = graphql_sync(schema, parametrized_query, variable_values=variables) + assert result.errors is None + assert result.data == {"testInput": True} + + +def test_attempt_deserialize_str_variable_without_valid_date_raises_error(): + variables = {"value": "invalid string"} + result = graphql_sync(schema, parametrized_query, variable_values=variables) + assert result.errors is not None + assert str(result.errors[0]).splitlines()[:1] == [ + "Variable '$value' got invalid value 'invalid string'; " + "Expected type 'DateInput'. " + "time data 'invalid string' does not match format '%Y-%m-%d'" + ] + + +def test_attempt_deserialize_wrong_type_variable_raises_error(): + variables = {"value": 123} + result = graphql_sync(schema, parametrized_query, variable_values=variables) + assert result.errors is not None + assert str(result.errors[0]).splitlines()[:1] == [ + "Variable '$value' got invalid value 123; Expected type 'DateInput'. " + "strptime() argument 1 must be str, not int" + ] + + +def test_literal_string_is_deserialized_by_default_parser(): + result = graphql_sync(schema, '{ testInputValueType(value: "test") }') + assert result.errors is None + assert result.data == {"testInputValueType": "str"} + + +def test_literal_int_is_deserialized_by_default_parser(): + result = graphql_sync(schema, "{ testInputValueType(value: 123) }") + assert result.errors is None + assert result.data == {"testInputValueType": "int"} + + +def test_literal_float_is_deserialized_by_default_parser(): + result = graphql_sync(schema, "{ testInputValueType(value: 1.5) }") + assert result.errors is None + assert result.data == {"testInputValueType": "float"} + + +def test_literal_bool_true_is_deserialized_by_default_parser(): + result = graphql_sync(schema, "{ testInputValueType(value: true) }") + assert result.errors is None + assert result.data == {"testInputValueType": "bool"} + + +def test_literal_bool_false_is_deserialized_by_default_parser(): + result = graphql_sync(schema, "{ testInputValueType(value: false) }") + assert result.errors is None + assert result.data == {"testInputValueType": "bool"} + + +def test_literal_object_is_deserialized_by_default_parser(): + result = graphql_sync(schema, "{ testInputValueType(value: {}) }") + assert result.errors is None + assert result.data == {"testInputValueType": "dict"} + + +def test_literal_list_is_deserialized_by_default_parser(): + result = graphql_sync(schema, "{ testInputValueType(value: []) }") + assert result.errors is None + assert result.data == {"testInputValueType": "list"} diff --git a/tests/test_subscription_type.py b/tests/test_subscription_type.py new file mode 100644 index 0000000..d870060 --- /dev/null +++ b/tests/test_subscription_type.py @@ -0,0 +1,319 @@ +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, build_schema + +from ariadne_graphql_modules import ( + DirectiveType, + InterfaceType, + ObjectType, + SubscriptionType, +) + + +def test_subscription_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class UsersSubscription(SubscriptionType): + pass + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class UsersSubscription(SubscriptionType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class UsersSubscription(SubscriptionType): + __schema__ = "typo Subscription" + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UsersSubscription(SubscriptionType): + __schema__ = "scalar Subscription" + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_with_invalid_graphql_type_name( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UsersSubscription(SubscriptionType): + __schema__ = "type Other" + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_without_fields(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class UsersSubscription(SubscriptionType): + __schema__ = "type Subscription" + + snapshot.assert_match(err) + + +def test_subscription_type_extracts_graphql_name(): + class UsersSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + thread: ID! + } + """ + + assert UsersSubscription.graphql_name == "Subscription" + + +def test_subscription_type_raises_error_when_defined_without_return_type_dependency( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: Chat + Chats: [Chat!] + } + """ + + snapshot.assert_match(err) + + +def test_subscription_type_verifies_field_dependency(): + # pylint: disable=unused-variable + class ChatType(ObjectType): + __schema__ = """ + type Chat { + id: ID! + } + """ + + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: Chat + Chats: [Chat!] + } + """ + __requires__ = [ChatType] + + +def test_subscription_type_raises_error_when_defined_without_argument_type_dependency( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat(input: ChannelInput): [String!]! + } + """ + + snapshot.assert_match(err) + + +def test_subscription_type_can_be_extended_with_new_fields(): + # pylint: disable=unused-variable + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: ID! + } + """ + + class ExtendChatSubscription(SubscriptionType): + __schema__ = """ + extend type Subscription { + thread: ID! + } + """ + __requires__ = [ChatSubscription] + + +def test_subscription_type_can_be_extended_with_directive(): + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on OBJECT" + __visitor__ = SchemaDirectiveVisitor + + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: ID! + } + """ + + class ExtendChatSubscription(SubscriptionType): + __schema__ = "extend type Subscription @example" + __requires__ = [ChatSubscription, ExampleDirective] + + +def test_subscription_type_can_be_extended_with_interface(): + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Interface { + threads: ID! + } + """ + + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: ID! + } + """ + + class ExtendChatSubscription(SubscriptionType): + __schema__ = """ + extend type Subscription implements Interface { + threads: ID! + } + """ + __requires__ = [ChatSubscription, ExampleInterface] + + +def test_subscription_type_raises_error_when_defined_without_extended_dependency( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExtendChatSubscription(SubscriptionType): + __schema__ = """ + extend type Subscription { + thread: ID! + } + """ + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_extended_dependency_is_wrong_type( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleInterface(InterfaceType): + __schema__ = """ + interface Subscription { + id: ID! + } + """ + + class ExtendChatSubscription(SubscriptionType): + __schema__ = """ + extend type Subscription { + thread: ID! + } + """ + __requires__ = [ExampleInterface] + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_with_alias_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: ID! + } + """ + __aliases__ = { + "userAlerts": "user_alerts", + } + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_with_resolver_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: ID! + } + """ + + @staticmethod + def resolve_group(*_): + return None + + snapshot.assert_match(err) + + +def test_subscription_type_raises_error_when_defined_with_sub_for_nonexisting_field( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: ID! + } + """ + + @staticmethod + def subscribe_group(*_): + return None + + snapshot.assert_match(err) + + +def test_subscription_type_binds_resolver_and_subscriber_to_schema(): + schema = build_schema( + """ + type Query { + hello: String + } + + type Subscription { + chat: ID! + } + """ + ) + + class ChatSubscription(SubscriptionType): + __schema__ = """ + type Subscription { + chat: ID! + } + """ + + @staticmethod + def resolve_chat(*_): + return None + + @staticmethod + def subscribe_chat(*_): + return None + + ChatSubscription.__bind_to_schema__(schema) + + field = schema.type_map.get("Subscription").fields["chat"] + assert field.resolve is ChatSubscription.resolve_chat + assert field.subscribe is ChatSubscription.subscribe_chat diff --git a/tests/test_union_type.py b/tests/test_union_type.py new file mode 100644 index 0000000..5b1f30f --- /dev/null +++ b/tests/test_union_type.py @@ -0,0 +1,243 @@ +from dataclasses import dataclass + +import pytest +from ariadne import SchemaDirectiveVisitor +from graphql import GraphQLError, graphql_sync + +from ariadne_graphql_modules import ( + DirectiveType, + ObjectType, + UnionType, + make_executable_schema, +) + + +def test_union_type_raises_attribute_error_when_defined_without_schema(snapshot): + with pytest.raises(AttributeError) as err: + # pylint: disable=unused-variable + class ExampleUnion(UnionType): + pass + + snapshot.assert_match(err) + + +def test_union_type_raises_error_when_defined_with_invalid_schema_type(snapshot): + with pytest.raises(TypeError) as err: + # pylint: disable=unused-variable + class ExampleUnion(UnionType): + __schema__ = True + + snapshot.assert_match(err) + + +def test_union_type_raises_error_when_defined_with_invalid_schema_str(snapshot): + with pytest.raises(GraphQLError) as err: + # pylint: disable=unused-variable + class ExampleUnion(UnionType): + __schema__ = "unien Example = A | B" + + snapshot.assert_match(err) + + +def test_union_type_raises_error_when_defined_with_invalid_graphql_type_schema( + snapshot, +): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleUnion(UnionType): + __schema__ = "scalar DateTime" + + snapshot.assert_match(err) + + +def test_union_type_raises_error_when_defined_with_multiple_types_schema(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleUnion(UnionType): + __schema__ = """ + union A = C | D + + union B = C | D + """ + + snapshot.assert_match(err) + + +@dataclass +class User: + id: int + name: str + + +@dataclass +class Comment: + id: int + message: str + + +class UserType(ObjectType): + __schema__ = """ + type User { + id: ID! + name: String! + } + """ + + +class CommentType(ObjectType): + __schema__ = """ + type Comment { + id: ID! + message: String! + } + """ + + +class ResultUnion(UnionType): + __schema__ = "union Result = Comment | User" + __requires__ = [CommentType, UserType] + + @staticmethod + def resolve_type(instance, *_): + if isinstance(instance, Comment): + return "Comment" + + if isinstance(instance, User): + return "User" + + return None + + +class QueryType(ObjectType): + __schema__ = """ + type Query { + results: [Result!]! + } + """ + __requires__ = [ResultUnion] + + @staticmethod + def resolve_results(*_): + return [ + User(id=1, name="Alice"), + Comment(id=1, message="Hello world!"), + ] + + +schema = make_executable_schema(QueryType, UserType, CommentType) + + +def test_union_type_extracts_graphql_name(): + class ExampleUnion(UnionType): + __schema__ = "union Example = User | Comment" + __requires__ = [UserType, CommentType] + + assert ExampleUnion.graphql_name == "Example" + + +def test_union_type_raises_error_when_defined_without_member_type_dependency(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleUnion(UnionType): + __schema__ = "union Example = User | Comment" + __requires__ = [UserType] + + snapshot.assert_match(err) + + +def test_interface_type_binds_type_resolver(): + query = """ + query { + results { + ... on User { + __typename + id + name + } + ... on Comment { + __typename + id + message + } + } + } + """ + + result = graphql_sync(schema, query) + assert result.data == { + "results": [ + { + "__typename": "User", + "id": "1", + "name": "Alice", + }, + { + "__typename": "Comment", + "id": "1", + "message": "Hello world!", + }, + ], + } + + +def test_union_type_can_be_extended_with_new_types(): + # pylint: disable=unused-variable + class ExampleUnion(UnionType): + __schema__ = "union Result = User | Comment" + __requires__ = [UserType, CommentType] + + class ThreadType(ObjectType): + __schema__ = """ + type Thread { + id: ID! + title: String! + } + """ + + class ExtendExampleUnion(UnionType): + __schema__ = "union Result = Thread" + __requires__ = [ExampleUnion, ThreadType] + + +def test_union_type_can_be_extended_with_directive(): + # pylint: disable=unused-variable + class ExampleDirective(DirectiveType): + __schema__ = "directive @example on UNION" + __visitor__ = SchemaDirectiveVisitor + + class ExampleUnion(UnionType): + __schema__ = "union Result = User | Comment" + __requires__ = [UserType, CommentType] + + class ExtendExampleUnion(UnionType): + __schema__ = """ + extend union Result @example + """ + __requires__ = [ExampleUnion, ExampleDirective] + + +def test_union_type_raises_error_when_defined_without_extended_dependency(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExtendExampleUnion(UnionType): + __schema__ = "extend union Result = User" + __requires__ = [UserType] + + snapshot.assert_match(err) + + +def test_interface_type_raises_error_when_extended_dependency_is_wrong_type(snapshot): + with pytest.raises(ValueError) as err: + # pylint: disable=unused-variable + class ExampleType(ObjectType): + __schema__ = """ + type Example { + id: ID! + } + """ + + class ExtendExampleUnion(UnionType): + __schema__ = "extend union Example = User" + __requires__ = [ExampleType, UserType] + + snapshot.assert_match(err)