diff --git a/ariadne_graphql_modules/next/__init__.py b/ariadne_graphql_modules/next/__init__.py index 69261ee..0e4c512 100644 --- a/ariadne_graphql_modules/next/__init__.py +++ b/ariadne_graphql_modules/next/__init__.py @@ -20,6 +20,7 @@ from .sort import sort_schema_document from .uniontype import GraphQLUnion, GraphQLUnionModel from .value import get_value_from_node, get_value_node +from .interfacetype import GraphQLInterface, GraphQLInterfaceModel __all__ = [ "GraphQLEnum", @@ -27,6 +28,8 @@ "GraphQLID", "GraphQLInput", "GraphQLInputModel", + "GraphQLInterface", + "GraphQLInterfaceModel", "GraphQLMetadata", "GraphQLModel", "GraphQLObject", diff --git a/ariadne_graphql_modules/next/interfacetype.py b/ariadne_graphql_modules/next/interfacetype.py new file mode 100644 index 0000000..09fc6fb --- /dev/null +++ b/ariadne_graphql_modules/next/interfacetype.py @@ -0,0 +1,213 @@ +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + List, + cast, +) + +from ariadne import InterfaceType +from ariadne.types import Resolver +from graphql import ( + FieldDefinitionNode, + GraphQLField, + GraphQLObjectType, + GraphQLSchema, + InputValueDefinitionNode, + NameNode, + NamedTypeNode, + InterfaceTypeDefinitionNode, +) + +from .value import get_value_node + +from .objecttype import ( + GraphQLObject, + GraphQLObjectResolver, + get_field_args_from_resolver, + get_field_args_out_names, + get_field_node_from_obj_field, + get_graphql_object_data, + update_field_args_options, +) + +from ..utils import parse_definition +from .base import GraphQLMetadata, GraphQLModel +from .description import get_description_node + + +class GraphQLInterface(GraphQLObject): + __abstract__: bool = True + __valid_type__ = InterfaceTypeDefinitionNode + + @classmethod + def __get_graphql_model_with_schema__( + cls, metadata: GraphQLMetadata, name: str + ) -> "GraphQLInterfaceModel": + definition = cast( + InterfaceTypeDefinitionNode, + parse_definition(InterfaceTypeDefinitionNode, cls.__schema__), + ) + + descriptions: Dict[str, str] = {} + args_descriptions: Dict[str, Dict[str, str]] = {} + args_defaults: Dict[str, Dict[str, Any]] = {} + resolvers: Dict[str, Resolver] = {} + out_names: Dict[str, Dict[str, str]] = {} + + for attr_name in dir(cls): + cls_attr = getattr(cls, attr_name) + if isinstance(cls_attr, GraphQLObjectResolver): + resolvers[cls_attr.field] = cls_attr.resolver + if cls_attr.description: + descriptions[cls_attr.field] = get_description_node( + cls_attr.description + ) + + field_args = get_field_args_from_resolver(cls_attr.resolver) + if field_args: + args_descriptions[cls_attr.field] = {} + args_defaults[cls_attr.field] = {} + + final_args = update_field_args_options(field_args, cls_attr.args) + + for arg_name, arg_options in final_args.items(): + arg_description = get_description_node(arg_options.description) + if arg_description: + args_descriptions[cls_attr.field][arg_name] = ( + arg_description + ) + + arg_default = arg_options.default_value + if arg_default is not None: + args_defaults[cls_attr.field][arg_name] = get_value_node( + arg_default + ) + + fields: List[FieldDefinitionNode] = [] + for field in definition.fields: + field_args_descriptions = args_descriptions.get(field.name.value, {}) + field_args_defaults = args_defaults.get(field.name.value, {}) + + args: List[InputValueDefinitionNode] = [] + for arg in field.arguments: + arg_name = arg.name.value + args.append( + InputValueDefinitionNode( + description=( + arg.description or field_args_descriptions.get(arg_name) + ), + name=arg.name, + directives=arg.directives, + type=arg.type, + default_value=( + arg.default_value or field_args_defaults.get(arg_name) + ), + ) + ) + + fields.append( + FieldDefinitionNode( + name=field.name, + description=( + field.description or descriptions.get(field.name.value) + ), + directives=field.directives, + arguments=tuple(args), + type=field.type, + ) + ) + + return GraphQLInterfaceModel( + name=definition.name.value, + ast_type=InterfaceTypeDefinitionNode, + ast=InterfaceTypeDefinitionNode( + name=NameNode(value=definition.name.value), + fields=tuple(fields), + interfaces=definition.interfaces, + ), + resolve_type=cls.resolve_type, + resolvers=resolvers, + aliases=getattr(cls, "__aliases__", {}), + out_names=out_names, + ) + + @classmethod + def __get_graphql_model_without_schema__( + cls, metadata: GraphQLMetadata, name: str + ) -> "GraphQLInterfaceModel": + type_data = get_graphql_object_data(metadata, cls) + type_aliases = getattr(cls, "__aliases__", None) or {} + + fields_ast: List[FieldDefinitionNode] = [] + resolvers: Dict[str, Resolver] = {} + aliases: Dict[str, str] = {} + out_names: Dict[str, Dict[str, str]] = {} + + for attr_name, field in type_data.fields.items(): + fields_ast.append(get_field_node_from_obj_field(cls, metadata, field)) + + if attr_name in type_aliases: + aliases[field.name] = type_aliases[attr_name] + elif attr_name != field.name: + aliases[field.name] = attr_name + + if field.resolver: + resolvers[field.name] = field.resolver + + if field.args: + out_names[field.name] = get_field_args_out_names(field.args) + + interfaces_ast: List[NamedTypeNode] = [] + for interface_name in type_data.interfaces: + interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name))) + + return GraphQLInterfaceModel( + name=name, + ast_type=InterfaceTypeDefinitionNode, + ast=InterfaceTypeDefinitionNode( + name=NameNode(value=name), + description=get_description_node( + getattr(cls, "__description__", None), + ), + fields=tuple(fields_ast), + interfaces=tuple(interfaces_ast), + ), + resolve_type=cls.resolve_type, + resolvers=resolvers, + aliases=aliases, + out_names=out_names, + ) + + @staticmethod + def resolve_type(obj: Any, *args) -> str: + if isinstance(obj, GraphQLInterface): + return obj.__get_graphql_name__() + + raise ValueError( + f"Cannot resolve GraphQL type {obj} for object of type '{type(obj).__name__}'." + ) + + +@dataclass(frozen=True) +class GraphQLInterfaceModel(GraphQLModel): + resolvers: Dict[str, Resolver] + resolve_type: Callable[[Any], Any] + out_names: Dict[str, Dict[str, str]] + aliases: Dict[str, str] + + def bind_to_schema(self, schema: GraphQLSchema): + bindable = InterfaceType(self.name, self.resolve_type) + for field, resolver in self.resolvers.items(): + bindable.set_field(field, resolver) + for alias, target in self.aliases.items(): + bindable.set_alias(alias, target) + + bindable.bind_to_schema(schema) + + graphql_type = cast(GraphQLObjectType, schema.get_type(self.name)) + for field_name, field_out_names in self.out_names.items(): + graphql_field = cast(GraphQLField, graphql_type.fields[field_name]) + for arg_name, out_name in field_out_names.items(): + graphql_field.args[arg_name].out_name = out_name diff --git a/ariadne_graphql_modules/next/objecttype.py b/ariadne_graphql_modules/next/objecttype.py index 62c2722..9dcb028 100644 --- a/ariadne_graphql_modules/next/objecttype.py +++ b/ariadne_graphql_modules/next/objecttype.py @@ -23,7 +23,9 @@ GraphQLSchema, InputValueDefinitionNode, NameNode, + NamedTypeNode, ObjectTypeDefinitionNode, + TypeDefinitionNode, ) from ..utils import parse_definition @@ -42,6 +44,7 @@ class GraphQLObject(GraphQLType): __description__: Optional[str] __aliases__: Optional[Dict[str, str]] __requires__: Optional[Iterable[Union[Type[GraphQLType], Type[Enum]]]] + __implements__: Optional[Iterable[Type[GraphQLType]]] def __init__(self, **kwargs: Any): for kwarg in kwargs: @@ -65,7 +68,8 @@ def __init_subclass__(cls) -> None: cls.__abstract__ = False if cls.__dict__.get("__schema__"): - cls.__kwargs__ = validate_object_type_with_schema(cls) + valid_type = getattr(cls, "__valid_type__", ObjectTypeDefinitionNode) + cls.__kwargs__ = validate_object_type_with_schema(cls, valid_type) else: cls.__kwargs__ = validate_object_type_without_schema(cls) @@ -113,9 +117,9 @@ def __get_graphql_model_with_schema__( for arg_name, arg_options in final_args.items(): arg_description = get_description_node(arg_options.description) if arg_description: - args_descriptions[cls_attr.field][ - arg_name - ] = arg_description + args_descriptions[cls_attr.field][arg_name] = ( + arg_description + ) arg_default = arg_options.default_value if arg_default is not None: @@ -163,6 +167,7 @@ def __get_graphql_model_with_schema__( ast=ObjectTypeDefinitionNode( name=NameNode(value=definition.name.value), fields=tuple(fields), + interfaces=definition.interfaces, ), resolvers=resolvers, aliases=getattr(cls, "__aliases__", {}), @@ -195,6 +200,10 @@ def __get_graphql_model_without_schema__( if field.args: out_names[field.name] = get_field_args_out_names(field.args) + interfaces_ast: List[NamedTypeNode] = [] + for interface_name in type_data.interfaces: + interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name))) + return GraphQLObjectModel( name=name, ast_type=ObjectTypeDefinitionNode, @@ -204,6 +213,7 @@ def __get_graphql_model_without_schema__( getattr(cls, "__description__", None), ), fields=tuple(fields_ast), + interfaces=tuple(interfaces_ast), ), resolvers=resolvers, aliases=aliases, @@ -226,6 +236,7 @@ def __get_graphql_types_with_schema__( ) -> Iterable["GraphQLType"]: types: List[GraphQLType] = [cls] types.extend(getattr(cls, "__requires__", [])) + types.extend(getattr(cls, "__implements__", [])) return types @classmethod @@ -305,6 +316,7 @@ def argument( @dataclass(frozen=True) class GraphQLObjectData: fields: Dict[str, "GraphQLObjectField"] + interfaces: List[str] def get_graphql_object_data( @@ -341,6 +353,10 @@ def create_graphql_object_data_without_schema( aliases: Dict[str, str] = getattr(cls, "__aliases__", None) or {} aliases_targets: List[str] = list(aliases.values()) + interfaces: List[str] = [ + interface.__name__ for interface in getattr(cls, "__implements__", []) + ] + for attr_name, attr_type in type_hints.items(): if attr_name.startswith("__"): continue @@ -411,7 +427,7 @@ def create_graphql_object_data_without_schema( default_value=fields_defaults.get(field_name), ) - return GraphQLObjectData(fields=fields) + return GraphQLObjectData(fields=fields, interfaces=interfaces) class GraphQLObjectField: @@ -592,7 +608,7 @@ def get_field_args_from_resolver( def get_field_args_out_names( - field_args: Dict[str, GraphQLObjectFieldArg] + field_args: Dict[str, GraphQLObjectFieldArg], ) -> Dict[str, str]: out_names: Dict[str, str] = {} for field_arg in field_args.values(): @@ -661,15 +677,18 @@ def update_field_args_options( return updated_args -def validate_object_type_with_schema(cls: Type[GraphQLObject]) -> Dict[str, Any]: - definition = parse_definition(ObjectTypeDefinitionNode, cls.__schema__) +def validate_object_type_with_schema( + cls: Type[GraphQLObject], + valid_type: Type[TypeDefinitionNode] = ObjectTypeDefinitionNode, +) -> Dict[str, Any]: + definition = parse_definition(valid_type, cls.__schema__) - if not isinstance(definition, ObjectTypeDefinitionNode): + if not isinstance(definition, valid_type): raise ValueError( f"Class '{cls.__name__}' defines '__schema__' attribute " "with declaration for an invalid GraphQL type. " f"('{definition.__class__.__name__}' != " - f"'{ObjectTypeDefinitionNode.__name__}')" + f"'{valid_type.__name__}')" ) validate_name(cls, definition) diff --git a/tests_next/snapshots/snap_test_interface_type_validation.py b/tests_next/snapshots/snap_test_interface_type_validation.py new file mode 100644 index 0000000..fc72186 --- /dev/null +++ b/tests_next/snapshots/snap_test_interface_type_validation.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import Snapshot + + +snapshots = Snapshot() + +snapshots['test_interface_no_interface_in_schema 1'] = "Unknown type 'BaseInterface'." + +snapshots['test_interface_with_different_types 1'] = '''Query root type must be provided. + +Interface field UserInterface.score expects type String! but User.score is type Int!.''' + +snapshots['test_missing_interface_implementation 1'] = '''Query root type must be provided. + +Interface field RequiredInterface.requiredField expected but Implementing does not provide it.''' diff --git a/tests_next/test_interface_type.py b/tests_next/test_interface_type.py new file mode 100644 index 0000000..163ce19 --- /dev/null +++ b/tests_next/test_interface_type.py @@ -0,0 +1,278 @@ +from typing import List + +from graphql import graphql_sync + +from ariadne_graphql_modules.next import ( + GraphQLID, + GraphQLObject, + GraphQLInterface, + GraphQLUnion, + make_executable_schema, +) + + +class CommentType(GraphQLObject): + id: GraphQLID + content: str + + +def test_interface_without_schema(assert_schema_equals): + class UserInterface(GraphQLInterface): + summary: str + score: int + + class UserType(GraphQLObject): + name: str + summary: str + score: int + + __implements__ = [UserInterface] + + class ResultType(GraphQLUnion): + __types__ = [UserType, CommentType] + + class QueryType(GraphQLObject): + @GraphQLObject.field(type=List[ResultType]) + def search(*_) -> List[UserType | CommentType]: + return [ + UserType(id=1, username="Bob"), + CommentType(id=2, content="Hello World!"), + ] + + schema = make_executable_schema(QueryType, UserInterface, UserType) + + assert_schema_equals( + schema, + """ + type Query { + search: [Result!]! + } + + union Result = User | Comment + + type User implements UserInterface { + name: String! + summary: String! + score: Int! + } + + type Comment { + id: ID! + content: String! + } + + interface UserInterface { + summary: String! + score: Int! + } + + """, + ) + + +def test_interface_with_schema(assert_schema_equals): + class UserInterface(GraphQLInterface): + __schema__ = """ + interface UserInterface { + summary: String! + score: Int! + } + """ + + class UserType(GraphQLObject): + __schema__ = """ + type User implements UserInterface { + id: ID! + name: String! + summary: String! + score: Int! + } + """ + + __implements__ = [UserInterface] + + class ResultType(GraphQLUnion): + __types__ = [UserType, CommentType] + + class QueryType(GraphQLObject): + @GraphQLObject.field(type=List[ResultType]) + def search(*_) -> List[UserType | CommentType]: + return [ + UserType(id=1, username="Bob"), + CommentType(id=2, content="Hello World!"), + ] + + schema = make_executable_schema(QueryType, UserType) + + assert_schema_equals( + schema, + """ + type Query { + search: [Result!]! + } + + union Result = User | Comment + + type User implements UserInterface { + id: ID! + name: String! + summary: String! + score: Int! + } + + interface UserInterface { + summary: String! + score: Int! + } + + type Comment { + id: ID! + content: String! + } + + """, + ) + + +def test_interface_inheritance(assert_schema_equals): + class BaseEntityInterface(GraphQLInterface): + id: GraphQLID + + class UserInterface(GraphQLInterface): + id: GraphQLID + username: str + + __implements__ = [BaseEntityInterface] + + class UserType(GraphQLObject): + id: GraphQLID + username: str + + __implements__ = [UserInterface, BaseEntityInterface] + + class QueryType(GraphQLObject): + @GraphQLObject.field + def user(*_) -> UserType: + return UserType(id="1", username="test_user") + + schema = make_executable_schema( + QueryType, BaseEntityInterface, UserInterface, UserType + ) + + assert_schema_equals( + schema, + """ + type Query { + user: User! + } + + type User implements UserInterface & BaseEntityInterface { + id: ID! + username: String! + } + + interface UserInterface implements BaseEntityInterface { + id: ID! + username: String! + } + + interface BaseEntityInterface { + id: ID! + } + """, + ) + + +def test_interface_descriptions(assert_schema_equals): + class UserInterface(GraphQLInterface): + summary: str + score: int + + __description__ = "Lorem ipsum." + + class UserType(GraphQLObject): + id: GraphQLID + username: str + summary: str + score: int + + __implements__ = [UserInterface] + + class QueryType(GraphQLObject): + @GraphQLObject.field + def user(*_) -> UserType: + return UserType(id="1", username="test_user") + + schema = make_executable_schema(QueryType, UserType, UserInterface) + + assert_schema_equals( + schema, + """ + type Query { + user: User! + } + + type User implements UserInterface { + id: ID! + username: String! + summary: String! + score: Int! + } + + \"\"\"Lorem ipsum.\"\"\" + interface UserInterface { + summary: String! + score: Int! + } + """, + ) + + +def test_interface_resolvers_and_field_descriptions(assert_schema_equals): + class UserInterface(GraphQLInterface): + summary: str + score: int + + @GraphQLInterface.resolver("score", description="Lorem ipsum.") + def resolve_score(*_): + return 200 + + class UserType(GraphQLObject): + id: GraphQLID + summary: str + score: int + + __implements__ = [UserInterface] + + class QueryType(GraphQLObject): + @GraphQLObject.field + def user(*_) -> UserType: + return UserType(id="1") + + schema = make_executable_schema(QueryType, UserType, UserInterface) + + assert_schema_equals( + schema, + """ + type Query { + user: User! + } + + type User implements UserInterface { + id: ID! + summary: String! + score: Int! + } + + interface UserInterface { + summary: String! + + \"\"\"Lorem ipsum.\"\"\" + score: Int! + } + """, + ) + result = graphql_sync(schema, "{ user { score } }") + + assert not result.errors + assert result.data == {"user": {"score": 200}} diff --git a/tests_next/test_interface_type_validation.py b/tests_next/test_interface_type_validation.py new file mode 100644 index 0000000..bc2f414 --- /dev/null +++ b/tests_next/test_interface_type_validation.py @@ -0,0 +1,61 @@ +import pytest + +from ariadne_graphql_modules.next import ( + GraphQLID, + GraphQLObject, + GraphQLInterface, + make_executable_schema, +) + + +def test_interface_with_different_types(snapshot): + with pytest.raises(TypeError) as exc_info: + + class UserInterface(GraphQLInterface): + summary: str + score: str + + class UserType(GraphQLObject): + name: str + summary: str + score: int + + __implements__ = [UserInterface] + + make_executable_schema(UserType, UserInterface) + + snapshot.assert_match(str(exc_info.value)) + + +def test_missing_interface_implementation(snapshot): + with pytest.raises(TypeError) as exc_info: + + class RequiredInterface(GraphQLInterface): + required_field: str + + class ImplementingType(GraphQLObject): + optional_field: str + + __implements__ = [RequiredInterface] + + make_executable_schema(ImplementingType, RequiredInterface) + + snapshot.assert_match(str(exc_info.value)) + + +def test_interface_no_interface_in_schema(snapshot): + with pytest.raises(TypeError) as exc_info: + + class BaseInterface(GraphQLInterface): + id: GraphQLID + + class UserType(GraphQLObject): + id: GraphQLID + username: str + email: str + + __implements__ = [BaseInterface] + + make_executable_schema(UserType) + + snapshot.assert_match(str(exc_info.value))