Skip to content

Commit

Permalink
adjust code to newer changes
Browse files Browse the repository at this point in the history
  • Loading branch information
DamianCzajkowski committed Aug 19, 2024
1 parent 10f38d2 commit ad692b3
Show file tree
Hide file tree
Showing 221 changed files with 757 additions and 560 deletions.
4 changes: 2 additions & 2 deletions ariadne_graphql_modules/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
DefinitionNode,
GraphQLSchema,
ObjectTypeDefinitionNode,
TypeDefinitionNode,
TypeSystemDefinitionNode,
)

from .dependencies import Dependencies
Expand Down Expand Up @@ -36,7 +36,7 @@ class DefinitionType(BaseType):

graphql_name: str
graphql_type: Type[DefinitionNode]
graphql_def: TypeDefinitionNode
graphql_def: TypeSystemDefinitionNode

@classmethod
def __get_requirements__(cls) -> RequirementsDict:
Expand Down
14 changes: 8 additions & 6 deletions ariadne_graphql_modules/mutation_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,22 @@ def __init_subclass__(cls) -> None:

cls.__abstract__ = False

graphql_def = cls.__validate_schema__(
cls.graphql_def = cls.__validate_schema__(
parse_definition(cls.__name__, cls.__schema__)
)

cls.graphql_name = graphql_def.name.value
cls.graphql_type = type(graphql_def)
cls.graphql_name = cls.graphql_def.name.value
cls.graphql_type = type(cls.graphql_def)

field = cls.__get_field__(graphql_def)
field = cls.__get_field__(cls.graphql_def)
cls.mutation_name = field.name.value

requirements = cls.__get_requirements__()
cls.__validate_requirements_contain_extended_type__(graphql_def, requirements)
cls.__validate_requirements_contain_extended_type__(
cls.graphql_def, requirements
)

dependencies = cls.__get_dependencies__(graphql_def)
dependencies = cls.__get_dependencies__(cls.graphql_def)
cls.__validate_requirements__(requirements, dependencies)

if callable(cls.__args__):
Expand Down
86 changes: 44 additions & 42 deletions ariadne_graphql_modules/next/compatibility_layer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
from typing import List, Type
from enum import Enum
from inspect import isclass
from typing import Any, Dict, List, Type, Union, cast

from graphql import (
EnumTypeDefinitionNode,
InputObjectTypeDefinitionNode,
InterfaceTypeDefinitionNode,
NameNode,
ObjectTypeDefinitionNode,
ScalarTypeDefinitionNode,
TypeExtensionNode,
UnionTypeDefinitionNode,
)

from ariadne_graphql_modules.executable_schema import get_all_types
from ariadne_graphql_modules.next.inputtype import GraphQLInputModel
from ariadne_graphql_modules.next.interfacetype import (
GraphQLInterfaceModel,
)
from ariadne_graphql_modules.next.scalartype import GraphQLScalarModel
from ariadne_graphql_modules.next.subscriptiontype import GraphQLSubscriptionModel
from ariadne_graphql_modules.next.uniontype import GraphQLUnionModel
from ..executable_schema import get_all_types

from ..directive_type import DirectiveType
from ..enum_type import EnumType
Expand All @@ -30,10 +25,17 @@

from ..object_type import ObjectType

from .description import get_description_node
from ..bases import BindableType
from .base import GraphQLModel, GraphQLType
from . import GraphQLObjectModel, GraphQLEnumModel
from . import (
GraphQLObjectModel,
GraphQLEnumModel,
GraphQLInputModel,
GraphQLScalarModel,
GraphQLInterfaceModel,
GraphQLSubscriptionModel,
GraphQLUnionModel,
)


def wrap_legacy_types(
Expand All @@ -53,6 +55,8 @@ class LegacyGraphQLType(GraphQLType):

@classmethod
def __get_graphql_model__(cls, *_) -> GraphQLModel:
if issubclass(cls.__base_type__.graphql_type, TypeExtensionNode):
pass
if issubclass(cls.__base_type__, ObjectType):
return cls.construct_object_model(cls.__base_type__)
if issubclass(cls.__base_type__, EnumType):
Expand All @@ -69,51 +73,50 @@ def __get_graphql_model__(cls, *_) -> GraphQLModel:
return cls.construct_subscription_model(cls.__base_type__)
if issubclass(cls.__base_type__, UnionType):
return cls.construct_union_model(cls.__base_type__)
else:
raise ValueError(f"Unsupported base_type {cls.__base_type__}")
raise ValueError(f"Unsupported base_type {cls.__base_type__}")

@classmethod
def construct_object_model(
cls, base_type: Type[ObjectType]
) -> "GraphQLObjectModel":
name = base_type.graphql_name
description = base_type.__doc__

cls, base_type: Type[Union[ObjectType, MutationType]]
) -> GraphQLObjectModel:
return GraphQLObjectModel(
name=name,
name=base_type.graphql_name,
ast_type=ObjectTypeDefinitionNode,
ast=ObjectTypeDefinitionNode(
name=NameNode(value=name),
description=get_description_node(description),
fields=tuple(base_type.graphql_fields.values()),
interfaces=base_type.interfaces,
),
resolvers=base_type.resolvers,
aliases=base_type.__aliases__ or {},
out_names=base_type.__fields_args__ or {},
ast=cast(ObjectTypeDefinitionNode, base_type.graphql_def),
resolvers=base_type.resolvers, # type: ignore
aliases=base_type.__aliases__ or {}, # type: ignore
out_names={},
)

@classmethod
def construct_enum_model(cls, base_type: Type[EnumType]) -> GraphQLEnumModel:
members = base_type.__enum__ or {}
members_values: Dict[str, Any] = {}

if isinstance(members, dict):
members_values = dict(members.items())
elif isclass(members) and issubclass(members, Enum):
members_values = {member.name: member for member in members}

return GraphQLEnumModel(
name=base_type.graphql_name,
members=base_type.__enum__ or {},
members=members_values,
ast_type=EnumTypeDefinitionNode,
ast=base_type.graphql_def,
ast=cast(EnumTypeDefinitionNode, base_type.graphql_def),
)

@classmethod
def construct_directive_model(cls, base_type: Type[DirectiveType]) -> GraphQLModel:
def construct_directive_model(cls, base_type: Type[DirectiveType]):
"""TODO: https://github.com/mirumee/ariadne-graphql-modules/issues/29"""

@classmethod
def construct_input_model(cls, base_type: Type[InputType]) -> GraphQLInputModel:
return GraphQLInputModel(
name=base_type.graphql_name,
ast_type=InputObjectTypeDefinitionNode,
ast=base_type.graphql_def, # type: ignore
ast=cast(InputObjectTypeDefinitionNode, base_type.graphql_def),
out_type=base_type.graphql_type,
out_names=base_type.graphql_fields or {}, # type: ignore
out_names={},
)

@classmethod
Expand All @@ -123,19 +126,19 @@ def construct_interface_model(
return GraphQLInterfaceModel(
name=base_type.graphql_name,
ast_type=InterfaceTypeDefinitionNode,
ast=base_type.graphql_def,
ast=cast(InterfaceTypeDefinitionNode, base_type.graphql_def),
resolve_type=base_type.resolve_type,
resolvers=base_type.resolvers,
out_names={},
aliases=base_type.__aliases__ or {},
aliases=base_type.__aliases__ or {}, # type: ignore
)

@classmethod
def construct_scalar_model(cls, base_type: Type[ScalarType]) -> GraphQLScalarModel:
return GraphQLScalarModel(
name=base_type.graphql_name,
ast_type=ScalarTypeDefinitionNode,
ast=base_type.graphql_def,
ast=cast(ScalarTypeDefinitionNode, base_type.graphql_def),
serialize=base_type.serialize,
parse_value=base_type.parse_value,
parse_literal=base_type.parse_literal,
Expand All @@ -148,11 +151,10 @@ def construct_subscription_model(
return GraphQLSubscriptionModel(
name=base_type.graphql_name,
ast_type=ObjectTypeDefinitionNode,
ast=base_type.graphql_def,
resolve_type=None,
ast=cast(ObjectTypeDefinitionNode, base_type.graphql_def),
resolvers=base_type.resolvers,
aliases=base_type.__aliases__ or {},
out_names=base_type.__fields_args__ or {},
aliases=base_type.__aliases__ or {}, # type: ignore
out_names={},
subscribers=base_type.subscribers,
)

Expand All @@ -161,6 +163,6 @@ def construct_union_model(cls, base_type: Type[UnionType]) -> GraphQLUnionModel:
return GraphQLUnionModel(
name=base_type.graphql_name,
ast_type=UnionTypeDefinitionNode,
ast=base_type.graphql_def,
ast=cast(UnionTypeDefinitionNode, base_type.graphql_def),
resolve_type=base_type.resolve_type,
)
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, cast
from typing import Dict, cast

from ariadne import InterfaceType
from ariadne.types import Resolver
from graphql import GraphQLField, GraphQLObjectType, GraphQLSchema
from graphql import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLTypeResolver

from ..base import GraphQLModel


@dataclass(frozen=True)
class GraphQLInterfaceModel(GraphQLModel):
resolvers: Dict[str, Resolver]
resolve_type: Callable[[Any], Any]
resolve_type: GraphQLTypeResolver
out_names: Dict[str, Dict[str, str]]
aliases: Dict[str, str]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __get_graphql_model_without_schema__(

@staticmethod
def resolve_type(obj: Any, *_) -> str:
if isinstance(obj, GraphQLInterface):
if isinstance(obj, GraphQLObject):
return obj.__get_graphql_name__()

raise ValueError(
Expand Down
15 changes: 10 additions & 5 deletions ariadne_graphql_modules/next/graphql_scalar/scalar_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional
from typing import Optional

from graphql import GraphQLSchema, ValueNode
from graphql import (
GraphQLScalarLiteralParser,
GraphQLScalarSerializer,
GraphQLScalarValueParser,
GraphQLSchema,
)
from ariadne import ScalarType as ScalarTypeBindable
from ..base import GraphQLModel


@dataclass(frozen=True)
class GraphQLScalarModel(GraphQLModel):
serialize: Callable[[Any], Any]
parse_value: Callable[[Any], Any]
parse_literal: Callable[[ValueNode, Optional[Dict[str, Any]]], Any]
serialize: Optional[GraphQLScalarSerializer]
parse_value: Optional[GraphQLScalarValueParser]
parse_literal: Optional[GraphQLScalarLiteralParser]

def bind_to_schema(self, schema: GraphQLSchema):
bindable = ScalarTypeBindable(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, cast
from typing import Dict, cast

from ariadne import SubscriptionType
from ariadne.types import Resolver, Subscriber
Expand All @@ -11,7 +11,6 @@
@dataclass(frozen=True)
class GraphQLSubscriptionModel(GraphQLModel):
resolvers: Dict[str, Resolver]
resolve_type: Callable[[Any], Any]
out_names: Dict[str, Dict[str, str]]
aliases: Dict[str, str]
subscribers: Dict[str, Subscriber]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def __get_graphql_model_with_schema__(cls) -> "GraphQLModel":
fields=tuple(fields),
interfaces=definition.interfaces,
),
resolve_type=cls.resolve_type,
resolvers=resolvers,
subscribers=subscribers,
aliases=getattr(cls, "__aliases__", {}),
Expand Down Expand Up @@ -196,22 +195,12 @@ def __get_graphql_model_without_schema__(
fields=tuple(fields_ast),
interfaces=tuple(interfaces_ast),
),
resolve_type=cls.resolve_type,
resolvers=resolvers,
aliases=aliases,
out_names=out_names,
subscribers=subscribers,
)

@staticmethod
def resolve_type(obj: Any, *_) -> str:
if isinstance(obj, GraphQLSubscription):
return obj.__get_graphql_name__()

raise ValueError(
f"Cannot resolve GraphQL type {obj} for object of type '{type(obj).__name__}'."
)

@staticmethod
def source(
field: str,
Expand Down
5 changes: 2 additions & 3 deletions ariadne_graphql_modules/next/graphql_union/union_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from dataclasses import dataclass
from typing import Any, Callable

from ariadne import UnionType
from graphql import GraphQLSchema
from graphql import GraphQLSchema, GraphQLTypeResolver

from ..base import GraphQLModel


@dataclass(frozen=True)
class GraphQLUnionModel(GraphQLModel):
resolve_type: Callable[[Any], Any]
resolve_type: GraphQLTypeResolver

def bind_to_schema(self, schema: GraphQLSchema):
bindable = UnionType(self.name, self.resolve_type)
Expand Down
15 changes: 8 additions & 7 deletions ariadne_graphql_modules/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,20 @@ def __init_subclass__(cls) -> None:

cls.__abstract__ = False

graphql_def = cls.__validate_schema__(
cls.graphql_def = cls.__validate_schema__(
parse_definition(cls.__name__, cls.__schema__)
)

cls.graphql_name = graphql_def.name.value
cls.graphql_type = type(graphql_def)
cls.graphql_fields = cls.__get_fields__(graphql_def)
cls.interfaces = graphql_def.interfaces
cls.graphql_name = cls.graphql_def.name.value
cls.graphql_type = type(cls.graphql_def)
cls.graphql_fields = cls.__get_fields__(cls.graphql_def)

requirements = cls.__get_requirements__()
cls.__validate_requirements_contain_extended_type__(graphql_def, requirements)
cls.__validate_requirements_contain_extended_type__(
cls.graphql_def, requirements
)

dependencies = cls.__get_dependencies__(graphql_def)
dependencies = cls.__get_dependencies__(cls.graphql_def)
cls.__validate_requirements__(requirements, dependencies)

if callable(cls.__fields_args__):
Expand Down
2 changes: 1 addition & 1 deletion ariadne_graphql_modules/union_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init_subclass__(cls) -> None:
)

cls.graphql_name = cls.graphql_def.name.value
cls.graphql_type = type(cls.graphql_def)
cls.graphql_type = type(cls.graphql_def) # type: ignore

requirements = cls.__get_requirements__()
cls.__validate_requirements_contain_extended_type__(
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pathlib import Path
import pytest


@pytest.fixture(scope="session")
def datadir() -> Path:
return Path(__file__).parent / "snapshots"


@pytest.fixture(scope="session")
def original_datadir() -> Path:
return Path(__file__).parent / "snapshots"
Empty file removed tests/snapshots/__init__.py
Empty file.
Loading

0 comments on commit ad692b3

Please sign in to comment.