diff --git a/graphql_server/__init__.py b/graphql_server/__init__.py index 8942332..b96dfc2 100644 --- a/graphql_server/__init__.py +++ b/graphql_server/__init__.py @@ -15,7 +15,7 @@ from graphql.error import format_error as format_error_default from graphql.execution import ExecutionResult, execute from graphql.language import OperationType, parse -from graphql.pyutils import AwaitableOrValue +from graphql.pyutils import AwaitableOrValue, is_awaitable from graphql.type import GraphQLSchema, validate_schema from graphql.utilities import get_operation_ast from graphql.validation import ASTValidationRule, validate @@ -99,9 +99,7 @@ def run_http_query( if not is_batch: if not isinstance(data, (dict, MutableMapping)): - raise HttpQueryError( - 400, f"GraphQL params should be a dict. Received {data!r}." - ) + raise HttpQueryError(400, f"GraphQL params should be a dict. Received {data!r}.") data = [data] elif not batch_enabled: raise HttpQueryError(400, "Batch GraphQL requests are not enabled.") @@ -114,15 +112,10 @@ def run_http_query( if not is_batch: extra_data = query_data or {} - all_params: List[GraphQLParams] = [ - get_graphql_params(entry, extra_data) for entry in data - ] + all_params: List[GraphQLParams] = [get_graphql_params(entry, extra_data) for entry in data] results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [ - get_response( - schema, params, catch_exc, allow_only_query, run_sync, **execute_options - ) - for params in all_params + get_response(schema, params, catch_exc, allow_only_query, run_sync, **execute_options) for params in all_params ] return GraphQLResponse(results, all_params) @@ -160,10 +153,7 @@ def encode_execution_results( Returns a ServerResponse tuple with the serialized response as the first item and a status code of 200 or 400 in case any result was invalid as the second item. """ - results = [ - format_execution_result(execution_result, format_error) - for execution_result in execution_results - ] + results = [format_execution_result(execution_result, format_error) for execution_result in execution_results] result, status_codes = zip(*results) status_code = max(status_codes) @@ -274,14 +264,11 @@ def get_response( if operation != OperationType.QUERY.value: raise HttpQueryError( 405, - f"Can only perform a {operation} operation" - " from a POST request.", + f"Can only perform a {operation} operation" " from a POST request.", headers={"Allow": "POST"}, ) - validation_errors = validate( - schema, document, rules=validation_rules, max_errors=max_errors - ) + validation_errors = validate(schema, document, rules=validation_rules, max_errors=max_errors) if validation_errors: return ExecutionResult(data=None, errors=validation_errors) @@ -290,7 +277,7 @@ def get_response( document, variable_values=params.variables, operation_name=params.operation_name, - is_awaitable=assume_not_awaitable if run_sync else None, + is_awaitable=assume_not_awaitable if run_sync else is_awaitable, **kwargs, ) @@ -317,9 +304,7 @@ def format_execution_result( fe = [format_error(e) for e in execution_result.errors] # type: ignore response = {"errors": fe} - if execution_result.errors and any( - not getattr(e, "path", None) for e in execution_result.errors - ): + if execution_result.errors and any(not getattr(e, "path", None) for e in execution_result.errors): status_code = 400 else: response["data"] = execution_result.data diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index a417406..16a8c8b 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -1,3 +1,4 @@ +import asyncio import copy from collections.abc import MutableMapping from functools import partial @@ -5,7 +6,9 @@ from flask import Response, render_template_string, request from flask.views import View +from graphql import ExecutionResult from graphql.error import GraphQLError +from graphql.pyutils import is_awaitable from graphql.type.schema import GraphQLSchema from graphql_server import ( @@ -41,6 +44,7 @@ class GraphQLView(View): default_query = None header_editor_enabled = None should_persist_headers = None + enable_async = True methods = ["GET", "POST", "PUT", "DELETE"] @@ -53,19 +57,13 @@ def __init__(self, **kwargs): if hasattr(self, key): setattr(self, key, value) - assert isinstance( - self.schema, GraphQLSchema - ), "A Schema is required to be provided to GraphQLView." + assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView." def get_root_value(self): return self.root_value def get_context(self): - context = ( - copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) - else {} - ) + context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {} if isinstance(context, MutableMapping) and "request" not in context: context.update({"request": request}) return context @@ -73,6 +71,13 @@ def get_context(self): def get_middleware(self): return self.middleware + @staticmethod + def get_async_execution_results(execution_results): + async def await_execution_results(execution_results): + return [ex if ex is None or is_awaitable(ex) else await ex for ex in execution_results] + + return asyncio.run(await_execution_results(execution_results)) + def dispatch_request(self): try: request_method = request.method.lower() @@ -96,6 +101,11 @@ def dispatch_request(self): context_value=self.get_context(), middleware=self.get_middleware(), ) + + if self.enable_async: + if any(is_awaitable(ex) for ex in execution_results): + execution_results = self.get_async_execution_results(execution_results) + result, status_code = encode_execution_results( execution_results, is_batch=isinstance(data, list), @@ -123,9 +133,7 @@ def dispatch_request(self): header_editor_enabled=self.header_editor_enabled, should_persist_headers=self.should_persist_headers, ) - source = render_graphiql_sync( - data=graphiql_data, config=graphiql_config, options=graphiql_options - ) + source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options) return render_template_string(source) return Response(result, status=status_code, content_type="application/json") @@ -167,8 +175,4 @@ def should_display_graphiql(self): @staticmethod def request_wants_html(): best = request.accept_mimetypes.best_match(["application/json", "text/html"]) - return ( - best == "text/html" - and request.accept_mimetypes[best] - > request.accept_mimetypes["application/json"] - ) + return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]