Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ NEW: Initial implementation of GraphQL mutations #24

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions aiida_restapi/graphql/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
"""Defines plugins for AiiDA groups."""
# pylint: disable=too-few-public-methods,redefined-builtin,,unused-argument

from typing import Any, Optional
from typing import Any, Optional, Tuple

import graphene as gr
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.orm import Group

from aiida_restapi.filter_syntax import parse_filter_str
from aiida_restapi.graphql.nodes import NodesQuery
from aiida_restapi.graphql.plugins import QueryPlugin
from aiida_restapi.graphql.plugins import MutationPlugin, QueryPlugin

from .orm_factories import (
ENTITY_DICT_TYPE,
Expand Down Expand Up @@ -73,3 +74,47 @@ def resolve_Groups(
),
resolve_Groups,
)


class GroupCreate(gr.Mutation):
"""Create an AiiDA group (or change an existing one)."""

class Arguments:
"""The arguments to create a group."""

label = gr.String(required=True)
description = gr.String(default_value="")
type_string = gr.String()

created = gr.Boolean(
description="Whether the group was created or already existed."
)
group = gr.Field(lambda: GroupQuery)

@with_dbenv()
@staticmethod
def mutate(
root: Any,
info: gr.ResolveInfo,
label: str,
description: str = "",
type_string: Optional[str] = None,
) -> "GroupCreate":
"""Create the group and return the requested fields."""
output: Tuple[Group, bool] = Group.objects.get_or_create(
label=label, description=description, type_string=type_string
)
orm_group, created = output
if not created and not orm_group.description == description:
orm_group.description = description
group = GroupQuery(
id=orm_group.id,
uuid=orm_group.uuid,
label=orm_group.label,
type_string=orm_group.type_string,
description=orm_group.description,
)
return GroupCreate(group=group, created=created)


GroupCreatePlugin = MutationPlugin("groupCreate", GroupCreate)
75 changes: 71 additions & 4 deletions aiida_restapi/graphql/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
# -*- coding: utf-8 -*-
"""Main module that generates the full Graphql App."""
# pylint: disable=no-self-use,redefined-builtin,too-many-arguments,too-few-public-methods
from typing import Any, Callable, List, Optional

import graphene as gr
from starlette.concurrency import run_in_threadpool
from starlette.graphql import GraphQLApp

from .basic import aiidaVersionPlugin, rowLimitMaxPlugin
from .comments import CommentQueryPlugin, CommentsQueryPlugin
from .computers import ComputerQueryPlugin, ComputersQueryPlugin
from .entry_points import aiidaEntryPointGroupsPlugin, aiidaEntryPointsPlugin
from .groups import GroupQueryPlugin, GroupsQueryPlugin
from .groups import GroupCreatePlugin, GroupQueryPlugin, GroupsQueryPlugin
from .logs import LogQueryPlugin, LogsQueryPlugin
from .nodes import NodeQueryPlugin, NodesQueryPlugin
from .plugins import create_schema
from .users import UserQueryPlugin, UsersQueryPlugin

SCHEMA = create_schema(
[
queries=[
rowLimitMaxPlugin,
aiidaVersionPlugin,
aiidaEntryPointGroupsPlugin,
Expand All @@ -30,8 +35,70 @@
NodesQueryPlugin,
UserQueryPlugin,
UsersQueryPlugin,
]
],
mutations=[GroupCreatePlugin],
)


app = GraphQLApp(schema=SCHEMA)
class GraphQLAppWithMiddleware(GraphQLApp):
"""A GraphQLApp that exposes graphene middleware."""

def __init__(
self,
schema: gr.Schema,
executor: Any = None,
executor_class: Optional[type] = None,
graphiql: bool = True,
middleware: Optional[List[Any]] = None,
) -> None:
"""Initialise GraphQLApp."""
self.middleware = middleware
super().__init__(schema, executor, executor_class, graphiql)

async def execute( # type: ignore
self, query, variables=None, context=None, operation_name=None
):
"""Execute a query."""
if self.is_async:
return await self.schema.execute(
query,
variables=variables,
operation_name=operation_name,
executor=self.executor,
return_promise=True,
context=context,
middleware=self.middleware,
)

return await run_in_threadpool(
self.schema.execute,
query,
variables=variables,
operation_name=operation_name,
context=context,
middleware=self.middleware,
)


class AuthorizationMiddleware:
"""GraphQL middleware, to handle authentication of requests."""

def resolve(
self, next: Callable[..., Any], root: Any, info: gr.ResolveInfo, **args: Any
) -> Any:
"""Run before each field query resolution or mutation"""
# we can get the header of the request from the context
if "request" in info.context:
# print(info.context["request"].headers)
pass
# we can then check what type of operation is being performed and act accordingly
if info.operation.operation == "query":
# TODO allow only a certain number of queries in a single request?
pass
elif info.operation.operation == "mutation":
# TODO handle authentication via JWT token
pass
return next(root, info, **args)


app = GraphQLAppWithMiddleware(schema=SCHEMA, middleware=[AuthorizationMiddleware()])
45 changes: 40 additions & 5 deletions aiida_restapi/graphql/plugins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""Module defining the graphql plugin mechanism."""
from typing import Any, Callable, Dict, NamedTuple, Sequence, Type, Union
from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Type, Union

import graphene as gr

Expand All @@ -19,8 +19,10 @@ class QueryPlugin(NamedTuple):

def create_query(
queries: Sequence[QueryPlugin], docstring: str = "The root query"
) -> Type[gr.ObjectType]:
) -> Optional[Type[gr.ObjectType]]:
"""Generate a query from a sequence of query plugins."""
if not queries:
return None
# check that there are no duplicate names
name_map: Dict[str, QueryPlugin] = {}
# construct the dict of attributes/methods on the class
Expand All @@ -39,9 +41,39 @@ def create_query(
return type("RootQuery", (gr.ObjectType,), attr_map)


class MutationPlugin(NamedTuple):
"""Define a top-level mutation, to plugin to the schema."""

name: str
mutation: Type[gr.Mutation]


def create_mutations(
mutations: Sequence[MutationPlugin], docstring: str = "The root mutation"
) -> Optional[Type[gr.ObjectType]]:
"""Generate mutations from a sequence of mutation plugins."""
if not mutations:
return None
# check that there are no duplicate names
name_map: Dict[str, MutationPlugin] = {}
# construct the dict of attributes/methods on the class
attr_map: Dict[str, Any] = {}
for mutation in mutations:
if mutation.name in name_map:
raise ValueError(
f"Duplicate plugin name '{mutation.name}': {mutation} and {name_map[mutation.name]}"
)
name_map[mutation.name] = mutation
attr_map[mutation.name] = mutation.mutation.Field()
attr_map["__doc__"] = docstring
return type("RootMutation", (gr.ObjectType,), attr_map)


def create_schema(
queries: Sequence[QueryPlugin],
docstring: str = "The root query",
queries: Sequence[QueryPlugin] = (),
mutations: Sequence[MutationPlugin] = (),
query_docstring: str = "The root query",
mutations_docstring: str = "The root mutation",
auto_camelcase: bool = False,
**kwargs: Any,
) -> gr.Schema:
Expand All @@ -50,5 +82,8 @@ def create_schema(
Note we set auto_camelcase False, since this keeps database field names the same.
"""
return gr.Schema(
query=create_query(queries, docstring), auto_camelcase=auto_camelcase, **kwargs
query=create_query(queries, query_docstring),
mutation=create_mutations(mutations, mutations_docstring),
auto_camelcase=auto_camelcase,
**kwargs,
)
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@
"pydantic.types.Json",
"graphene.types.generic.GenericScalar",
"graphene.types.objecttype.ObjectType",
"graphene.types.mutation.Mutation",
"graphene.types.scalars.String",
"starlette.graphql.GraphQLApp",
"aiida_restapi.aiida_db_mappings.Config",
"aiida_restapi.models.Config",
"aiida_restapi.routers.auth.Config",
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def _func(
data,
{key: lambda k: type(k).__name__ for key in varfields},
)
if "data" in data:
# for graphql mutations this is an ordered dict
data["data"] = dict(data["data"])
data_regression.check(data)

return _func
Expand Down
17 changes: 16 additions & 1 deletion tests/test_graphql/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
"""Tests for group plugins."""
from graphene.test import Client

from aiida_restapi.graphql.groups import GroupQueryPlugin, GroupsQueryPlugin
from aiida_restapi.graphql.groups import (
GroupCreatePlugin,
GroupQueryPlugin,
GroupsQueryPlugin,
)
from aiida_restapi.graphql.orm_factories import field_names_from_orm
from aiida_restapi.graphql.plugins import create_schema

Expand Down Expand Up @@ -39,3 +43,14 @@ def test_groups(create_group, orm_regression):
client = Client(schema)
executed = client.execute("{ groups { count rows { %s } } }" % " ".join(fields))
orm_regression(executed)


def test_group_create(orm_regression):
"""Test Group creation."""
schema = create_schema(mutations=[GroupCreatePlugin])
client = Client(schema)
executed = client.execute(
'mutation { groupCreate(label: "group1", description: "hi") '
"{ created group { id uuid label type_string description } } }"
)
orm_regression(executed)
9 changes: 9 additions & 0 deletions tests/test_graphql/test_groups/test_group_create.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
data:
groupCreate:
created: true
group:
description: hi
id: int
label: group1
type_string: core
uuid: str