Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Base implementation of interface #37

Merged
merged 6 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ariadne_graphql_modules/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
from .sort import sort_schema_document
from .uniontype import GraphQLUnion, GraphQLUnionModel
from .value import get_value_from_node, get_value_node
from .interfacetype import GraphQLInterface, GraphQLInterfaceModel

__all__ = [
"GraphQLEnum",
"GraphQLEnumModel",
"GraphQLID",
"GraphQLInput",
"GraphQLInputModel",
"GraphQLInterface",
"GraphQLInterfaceModel",
"GraphQLMetadata",
"GraphQLModel",
"GraphQLObject",
Expand Down
213 changes: 213 additions & 0 deletions ariadne_graphql_modules/next/interfacetype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
List,
cast,
)

from ariadne import InterfaceType
from ariadne.types import Resolver
from graphql import (
FieldDefinitionNode,
GraphQLField,
GraphQLObjectType,
GraphQLSchema,
InputValueDefinitionNode,
NameNode,
NamedTypeNode,
InterfaceTypeDefinitionNode,
)

from .value import get_value_node

from .objecttype import (
GraphQLObject,
GraphQLObjectResolver,
get_field_args_from_resolver,
get_field_args_out_names,
get_field_node_from_obj_field,
get_graphql_object_data,
update_field_args_options,
)

from ..utils import parse_definition
from .base import GraphQLMetadata, GraphQLModel
from .description import get_description_node


class GraphQLInterface(GraphQLObject):
__abstract__: bool = True
__valid_type__ = InterfaceTypeDefinitionNode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets repeat __abstract__: bool = True for posterity sake.


@classmethod
def __get_graphql_model_with_schema__(
cls, metadata: GraphQLMetadata, name: str
) -> "GraphQLInterfaceModel":
definition = cast(
InterfaceTypeDefinitionNode,
parse_definition(InterfaceTypeDefinitionNode, cls.__schema__),
)

descriptions: Dict[str, str] = {}
args_descriptions: Dict[str, Dict[str, str]] = {}
args_defaults: Dict[str, Dict[str, Any]] = {}
resolvers: Dict[str, Resolver] = {}
out_names: Dict[str, Dict[str, str]] = {}

for attr_name in dir(cls):
cls_attr = getattr(cls, attr_name)
if isinstance(cls_attr, GraphQLObjectResolver):
resolvers[cls_attr.field] = cls_attr.resolver
if cls_attr.description:
descriptions[cls_attr.field] = get_description_node(
cls_attr.description
)

field_args = get_field_args_from_resolver(cls_attr.resolver)
if field_args:
args_descriptions[cls_attr.field] = {}
args_defaults[cls_attr.field] = {}

final_args = update_field_args_options(field_args, cls_attr.args)

for arg_name, arg_options in final_args.items():
arg_description = get_description_node(arg_options.description)
if arg_description:
args_descriptions[cls_attr.field][arg_name] = (
arg_description
)

arg_default = arg_options.default_value
if arg_default is not None:
args_defaults[cls_attr.field][arg_name] = get_value_node(
arg_default
)

fields: List[FieldDefinitionNode] = []
for field in definition.fields:
field_args_descriptions = args_descriptions.get(field.name.value, {})
field_args_defaults = args_defaults.get(field.name.value, {})

args: List[InputValueDefinitionNode] = []
for arg in field.arguments:
arg_name = arg.name.value
args.append(
InputValueDefinitionNode(
description=(
arg.description or field_args_descriptions.get(arg_name)
),
name=arg.name,
directives=arg.directives,
type=arg.type,
default_value=(
arg.default_value or field_args_defaults.get(arg_name)
),
)
)

fields.append(
FieldDefinitionNode(
name=field.name,
description=(
field.description or descriptions.get(field.name.value)
),
directives=field.directives,
arguments=tuple(args),
type=field.type,
)
)

return GraphQLInterfaceModel(
name=definition.name.value,
ast_type=InterfaceTypeDefinitionNode,
ast=InterfaceTypeDefinitionNode(
name=NameNode(value=definition.name.value),
fields=tuple(fields),
interfaces=definition.interfaces,
),
resolve_type=cls.resolve_type,
resolvers=resolvers,
aliases=getattr(cls, "__aliases__", {}),
out_names=out_names,
)

@classmethod
def __get_graphql_model_without_schema__(
cls, metadata: GraphQLMetadata, name: str
) -> "GraphQLInterfaceModel":
type_data = get_graphql_object_data(metadata, cls)
type_aliases = getattr(cls, "__aliases__", None) or {}

fields_ast: List[FieldDefinitionNode] = []
resolvers: Dict[str, Resolver] = {}
aliases: Dict[str, str] = {}
out_names: Dict[str, Dict[str, str]] = {}

for attr_name, field in type_data.fields.items():
fields_ast.append(get_field_node_from_obj_field(cls, metadata, field))

if attr_name in type_aliases:
aliases[field.name] = type_aliases[attr_name]
elif attr_name != field.name:
aliases[field.name] = attr_name

if field.resolver:
resolvers[field.name] = field.resolver

if field.args:
out_names[field.name] = get_field_args_out_names(field.args)

interfaces_ast: List[NamedTypeNode] = []
for interface_name in type_data.interfaces:
interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name)))

return GraphQLInterfaceModel(
name=name,
ast_type=InterfaceTypeDefinitionNode,
ast=InterfaceTypeDefinitionNode(
name=NameNode(value=name),
description=get_description_node(
getattr(cls, "__description__", None),
),
fields=tuple(fields_ast),
interfaces=tuple(interfaces_ast),
),
resolve_type=cls.resolve_type,
resolvers=resolvers,
aliases=aliases,
out_names=out_names,
)

@staticmethod
def resolve_type(obj: Any, *args) -> str:
if isinstance(obj, GraphQLInterface):
return obj.__get_graphql_name__()

raise ValueError(
f"Cannot resolve GraphQL type {obj} for object of type '{type(obj).__name__}'."
)


@dataclass(frozen=True)
class GraphQLInterfaceModel(GraphQLModel):
resolvers: Dict[str, Resolver]
resolve_type: Callable[[Any], Any]
out_names: Dict[str, Dict[str, str]]
aliases: Dict[str, str]

def bind_to_schema(self, schema: GraphQLSchema):
bindable = InterfaceType(self.name, self.resolve_type)
for field, resolver in self.resolvers.items():
bindable.set_field(field, resolver)
for alias, target in self.aliases.items():
bindable.set_alias(alias, target)

bindable.bind_to_schema(schema)

graphql_type = cast(GraphQLObjectType, schema.get_type(self.name))
for field_name, field_out_names in self.out_names.items():
graphql_field = cast(GraphQLField, graphql_type.fields[field_name])
for arg_name, out_name in field_out_names.items():
graphql_field.args[arg_name].out_name = out_name
39 changes: 29 additions & 10 deletions ariadne_graphql_modules/next/objecttype.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
GraphQLSchema,
InputValueDefinitionNode,
NameNode,
NamedTypeNode,
ObjectTypeDefinitionNode,
TypeDefinitionNode,
)

from ..utils import parse_definition
Expand All @@ -42,6 +44,7 @@ class GraphQLObject(GraphQLType):
__description__: Optional[str]
__aliases__: Optional[Dict[str, str]]
__requires__: Optional[Iterable[Union[Type[GraphQLType], Type[Enum]]]]
__implements__: Optional[Iterable[Type[GraphQLType]]]

def __init__(self, **kwargs: Any):
for kwarg in kwargs:
Expand All @@ -65,7 +68,8 @@ def __init_subclass__(cls) -> None:
cls.__abstract__ = False

if cls.__dict__.get("__schema__"):
cls.__kwargs__ = validate_object_type_with_schema(cls)
valid_type = getattr(cls, "__valid_type__", ObjectTypeDefinitionNode)
cls.__kwargs__ = validate_object_type_with_schema(cls, valid_type)
else:
cls.__kwargs__ = validate_object_type_without_schema(cls)

Expand Down Expand Up @@ -113,9 +117,9 @@ def __get_graphql_model_with_schema__(
for arg_name, arg_options in final_args.items():
arg_description = get_description_node(arg_options.description)
if arg_description:
args_descriptions[cls_attr.field][
arg_name
] = arg_description
args_descriptions[cls_attr.field][arg_name] = (
arg_description
)

arg_default = arg_options.default_value
if arg_default is not None:
Expand Down Expand Up @@ -163,6 +167,7 @@ def __get_graphql_model_with_schema__(
ast=ObjectTypeDefinitionNode(
name=NameNode(value=definition.name.value),
fields=tuple(fields),
interfaces=definition.interfaces,
),
resolvers=resolvers,
aliases=getattr(cls, "__aliases__", {}),
Expand Down Expand Up @@ -195,6 +200,10 @@ def __get_graphql_model_without_schema__(
if field.args:
out_names[field.name] = get_field_args_out_names(field.args)

interfaces_ast: List[NamedTypeNode] = []
for interface_name in type_data.interfaces:
interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name)))

return GraphQLObjectModel(
name=name,
ast_type=ObjectTypeDefinitionNode,
Expand All @@ -204,6 +213,7 @@ def __get_graphql_model_without_schema__(
getattr(cls, "__description__", None),
),
fields=tuple(fields_ast),
interfaces=tuple(interfaces_ast),
),
resolvers=resolvers,
aliases=aliases,
Expand All @@ -226,6 +236,7 @@ def __get_graphql_types_with_schema__(
) -> Iterable["GraphQLType"]:
types: List[GraphQLType] = [cls]
types.extend(getattr(cls, "__requires__", []))
types.extend(getattr(cls, "__implements__", []))
return types

@classmethod
Expand Down Expand Up @@ -305,6 +316,7 @@ def argument(
@dataclass(frozen=True)
class GraphQLObjectData:
fields: Dict[str, "GraphQLObjectField"]
interfaces: List[str]


def get_graphql_object_data(
Expand Down Expand Up @@ -341,6 +353,10 @@ def create_graphql_object_data_without_schema(
aliases: Dict[str, str] = getattr(cls, "__aliases__", None) or {}
aliases_targets: List[str] = list(aliases.values())

interfaces: List[str] = [
interface.__name__ for interface in getattr(cls, "__implements__", [])
]

for attr_name, attr_type in type_hints.items():
if attr_name.startswith("__"):
continue
Expand Down Expand Up @@ -411,7 +427,7 @@ def create_graphql_object_data_without_schema(
default_value=fields_defaults.get(field_name),
)

return GraphQLObjectData(fields=fields)
return GraphQLObjectData(fields=fields, interfaces=interfaces)


class GraphQLObjectField:
Expand Down Expand Up @@ -592,7 +608,7 @@ def get_field_args_from_resolver(


def get_field_args_out_names(
field_args: Dict[str, GraphQLObjectFieldArg]
field_args: Dict[str, GraphQLObjectFieldArg],
) -> Dict[str, str]:
out_names: Dict[str, str] = {}
for field_arg in field_args.values():
Expand Down Expand Up @@ -661,15 +677,18 @@ def update_field_args_options(
return updated_args


def validate_object_type_with_schema(cls: Type[GraphQLObject]) -> Dict[str, Any]:
definition = parse_definition(ObjectTypeDefinitionNode, cls.__schema__)
def validate_object_type_with_schema(
cls: Type[GraphQLObject],
valid_type: Type[TypeDefinitionNode] = ObjectTypeDefinitionNode,
) -> Dict[str, Any]:
definition = parse_definition(valid_type, cls.__schema__)

if not isinstance(definition, ObjectTypeDefinitionNode):
if not isinstance(definition, valid_type):
raise ValueError(
f"Class '{cls.__name__}' defines '__schema__' attribute "
"with declaration for an invalid GraphQL type. "
f"('{definition.__class__.__name__}' != "
f"'{ObjectTypeDefinitionNode.__name__}')"
f"'{valid_type.__name__}')"
)

validate_name(cls, definition)
Expand Down
18 changes: 18 additions & 0 deletions tests_next/snapshots/snap_test_interface_type_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# snapshottest: v1 - https://goo.gl/zC4yUc
from __future__ import unicode_literals

from snapshottest import Snapshot


snapshots = Snapshot()

snapshots['test_interface_no_interface_in_schema 1'] = "Unknown type 'BaseInterface'."

snapshots['test_interface_with_different_types 1'] = '''Query root type must be provided.

Interface field UserInterface.score expects type String! but User.score is type Int!.'''

snapshots['test_missing_interface_implementation 1'] = '''Query root type must be provided.

Interface field RequiredInterface.requiredField expected but Implementing does not provide it.'''
Loading
Loading