diff --git a/aiida_restapi/graphql/groups.py b/aiida_restapi/graphql/groups.py index bac9504..d54097e 100644 --- a/aiida_restapi/graphql/groups.py +++ b/aiida_restapi/graphql/groups.py @@ -2,14 +2,15 @@ """Defines plugins for AiiDA groups.""" # pylint: disable=too-few-public-methods,redefined-builtin,,unused-argument -from typing import Any, Optional +from typing import Any, Optional, Tuple import graphene as gr +from aiida.cmdline.utils.decorators import with_dbenv from aiida.orm import Group from aiida_restapi.filter_syntax import parse_filter_str from aiida_restapi.graphql.nodes import NodesQuery -from aiida_restapi.graphql.plugins import QueryPlugin +from aiida_restapi.graphql.plugins import MutationPlugin, QueryPlugin from .orm_factories import ( ENTITY_DICT_TYPE, @@ -73,3 +74,47 @@ def resolve_Groups( ), resolve_Groups, ) + + +class GroupCreate(gr.Mutation): + """Create an AiiDA group (or change an existing one).""" + + class Arguments: + """The arguments to create a group.""" + + label = gr.String(required=True) + description = gr.String(default_value="") + type_string = gr.String() + + created = gr.Boolean( + description="Whether the group was created or already existed." + ) + group = gr.Field(lambda: GroupQuery) + + @with_dbenv() + @staticmethod + def mutate( + root: Any, + info: gr.ResolveInfo, + label: str, + description: str = "", + type_string: Optional[str] = None, + ) -> "GroupCreate": + """Create the group and return the requested fields.""" + output: Tuple[Group, bool] = Group.objects.get_or_create( + label=label, description=description, type_string=type_string + ) + orm_group, created = output + if not created and not orm_group.description == description: + orm_group.description = description + group = GroupQuery( + id=orm_group.id, + uuid=orm_group.uuid, + label=orm_group.label, + type_string=orm_group.type_string, + description=orm_group.description, + ) + return GroupCreate(group=group, created=created) + + +GroupCreatePlugin = MutationPlugin("groupCreate", GroupCreate) diff --git a/aiida_restapi/graphql/main.py b/aiida_restapi/graphql/main.py index a1f0f61..89e3630 100644 --- a/aiida_restapi/graphql/main.py +++ b/aiida_restapi/graphql/main.py @@ -1,19 +1,24 @@ # -*- coding: utf-8 -*- """Main module that generates the full Graphql App.""" +# pylint: disable=no-self-use,redefined-builtin,too-many-arguments,too-few-public-methods +from typing import Any, Callable, List, Optional + +import graphene as gr +from starlette.concurrency import run_in_threadpool from starlette.graphql import GraphQLApp from .basic import aiidaVersionPlugin, rowLimitMaxPlugin from .comments import CommentQueryPlugin, CommentsQueryPlugin from .computers import ComputerQueryPlugin, ComputersQueryPlugin from .entry_points import aiidaEntryPointGroupsPlugin, aiidaEntryPointsPlugin -from .groups import GroupQueryPlugin, GroupsQueryPlugin +from .groups import GroupCreatePlugin, GroupQueryPlugin, GroupsQueryPlugin from .logs import LogQueryPlugin, LogsQueryPlugin from .nodes import NodeQueryPlugin, NodesQueryPlugin from .plugins import create_schema from .users import UserQueryPlugin, UsersQueryPlugin SCHEMA = create_schema( - [ + queries=[ rowLimitMaxPlugin, aiidaVersionPlugin, aiidaEntryPointGroupsPlugin, @@ -30,8 +35,70 @@ NodesQueryPlugin, UserQueryPlugin, UsersQueryPlugin, - ] + ], + mutations=[GroupCreatePlugin], ) -app = GraphQLApp(schema=SCHEMA) +class GraphQLAppWithMiddleware(GraphQLApp): + """A GraphQLApp that exposes graphene middleware.""" + + def __init__( + self, + schema: gr.Schema, + executor: Any = None, + executor_class: Optional[type] = None, + graphiql: bool = True, + middleware: Optional[List[Any]] = None, + ) -> None: + """Initialise GraphQLApp.""" + self.middleware = middleware + super().__init__(schema, executor, executor_class, graphiql) + + async def execute( # type: ignore + self, query, variables=None, context=None, operation_name=None + ): + """Execute a query.""" + if self.is_async: + return await self.schema.execute( + query, + variables=variables, + operation_name=operation_name, + executor=self.executor, + return_promise=True, + context=context, + middleware=self.middleware, + ) + + return await run_in_threadpool( + self.schema.execute, + query, + variables=variables, + operation_name=operation_name, + context=context, + middleware=self.middleware, + ) + + +class AuthorizationMiddleware: + """GraphQL middleware, to handle authentication of requests.""" + + def resolve( + self, next: Callable[..., Any], root: Any, info: gr.ResolveInfo, **args: Any + ) -> Any: + """Run before each field query resolution or mutation""" + # we can get the header of the request from the context + if "request" in info.context: + # print(info.context["request"].headers) + pass + # we can then check what type of operation is being performed and act accordingly + if info.operation.operation == "query": + # TODO allow only a certain number of queries in a single request? + pass + elif info.operation.operation == "mutation": + # TODO handle authentication via JWT token + pass + return next(root, info, **args) + + +app = GraphQLAppWithMiddleware(schema=SCHEMA, middleware=[AuthorizationMiddleware()]) diff --git a/aiida_restapi/graphql/plugins.py b/aiida_restapi/graphql/plugins.py index ea8254d..d3ae218 100644 --- a/aiida_restapi/graphql/plugins.py +++ b/aiida_restapi/graphql/plugins.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """Module defining the graphql plugin mechanism.""" -from typing import Any, Callable, Dict, NamedTuple, Sequence, Type, Union +from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Type, Union import graphene as gr @@ -19,8 +19,10 @@ class QueryPlugin(NamedTuple): def create_query( queries: Sequence[QueryPlugin], docstring: str = "The root query" -) -> Type[gr.ObjectType]: +) -> Optional[Type[gr.ObjectType]]: """Generate a query from a sequence of query plugins.""" + if not queries: + return None # check that there are no duplicate names name_map: Dict[str, QueryPlugin] = {} # construct the dict of attributes/methods on the class @@ -39,9 +41,39 @@ def create_query( return type("RootQuery", (gr.ObjectType,), attr_map) +class MutationPlugin(NamedTuple): + """Define a top-level mutation, to plugin to the schema.""" + + name: str + mutation: Type[gr.Mutation] + + +def create_mutations( + mutations: Sequence[MutationPlugin], docstring: str = "The root mutation" +) -> Optional[Type[gr.ObjectType]]: + """Generate mutations from a sequence of mutation plugins.""" + if not mutations: + return None + # check that there are no duplicate names + name_map: Dict[str, MutationPlugin] = {} + # construct the dict of attributes/methods on the class + attr_map: Dict[str, Any] = {} + for mutation in mutations: + if mutation.name in name_map: + raise ValueError( + f"Duplicate plugin name '{mutation.name}': {mutation} and {name_map[mutation.name]}" + ) + name_map[mutation.name] = mutation + attr_map[mutation.name] = mutation.mutation.Field() + attr_map["__doc__"] = docstring + return type("RootMutation", (gr.ObjectType,), attr_map) + + def create_schema( - queries: Sequence[QueryPlugin], - docstring: str = "The root query", + queries: Sequence[QueryPlugin] = (), + mutations: Sequence[MutationPlugin] = (), + query_docstring: str = "The root query", + mutations_docstring: str = "The root mutation", auto_camelcase: bool = False, **kwargs: Any, ) -> gr.Schema: @@ -50,5 +82,8 @@ def create_schema( Note we set auto_camelcase False, since this keeps database field names the same. """ return gr.Schema( - query=create_query(queries, docstring), auto_camelcase=auto_camelcase, **kwargs + query=create_query(queries, query_docstring), + mutation=create_mutations(mutations, mutations_docstring), + auto_camelcase=auto_camelcase, + **kwargs, ) diff --git a/docs/source/conf.py b/docs/source/conf.py index 0ded738..1838fbb 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -96,7 +96,9 @@ "pydantic.types.Json", "graphene.types.generic.GenericScalar", "graphene.types.objecttype.ObjectType", + "graphene.types.mutation.Mutation", "graphene.types.scalars.String", + "starlette.graphql.GraphQLApp", "aiida_restapi.aiida_db_mappings.Config", "aiida_restapi.models.Config", "aiida_restapi.routers.auth.Config", diff --git a/tests/conftest.py b/tests/conftest.py index 6ab2909..e687da3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,6 +48,9 @@ def _func( data, {key: lambda k: type(k).__name__ for key in varfields}, ) + if "data" in data: + # for graphql mutations this is an ordered dict + data["data"] = dict(data["data"]) data_regression.check(data) return _func diff --git a/tests/test_graphql/test_groups.py b/tests/test_graphql/test_groups.py index 8d979c0..09d0c0a 100644 --- a/tests/test_graphql/test_groups.py +++ b/tests/test_graphql/test_groups.py @@ -2,7 +2,11 @@ """Tests for group plugins.""" from graphene.test import Client -from aiida_restapi.graphql.groups import GroupQueryPlugin, GroupsQueryPlugin +from aiida_restapi.graphql.groups import ( + GroupCreatePlugin, + GroupQueryPlugin, + GroupsQueryPlugin, +) from aiida_restapi.graphql.orm_factories import field_names_from_orm from aiida_restapi.graphql.plugins import create_schema @@ -39,3 +43,14 @@ def test_groups(create_group, orm_regression): client = Client(schema) executed = client.execute("{ groups { count rows { %s } } }" % " ".join(fields)) orm_regression(executed) + + +def test_group_create(orm_regression): + """Test Group creation.""" + schema = create_schema(mutations=[GroupCreatePlugin]) + client = Client(schema) + executed = client.execute( + 'mutation { groupCreate(label: "group1", description: "hi") ' + "{ created group { id uuid label type_string description } } }" + ) + orm_regression(executed) diff --git a/tests/test_graphql/test_groups/test_group_create.yml b/tests/test_graphql/test_groups/test_group_create.yml new file mode 100644 index 0000000..a3d3650 --- /dev/null +++ b/tests/test_graphql/test_groups/test_group_create.yml @@ -0,0 +1,9 @@ +data: + groupCreate: + created: true + group: + description: hi + id: int + label: group1 + type_string: core + uuid: str