Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for resolver ExecutionResult return type (aka extensions support) #205

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions graphql/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def to_dict(self, format_error=None, dict_class=OrderedDict):
if not self.invalid:
response["data"] = self.data

if self.extensions:
response["extensions"] = self.extensions

return response


Expand All @@ -76,6 +79,7 @@ class ResolveInfo(object):
"variable_values",
"context",
"path",
"extensions",
)

def __init__(
Expand All @@ -91,6 +95,7 @@ def __init__(
variable_values, # type: Dict
context, # type: Optional[Any]
path=None, # type: Union[List[Union[int, str]], List[str]]
extensions=None, # type: Dict
):
# type: (...) -> None
self.field_name = field_name
Expand All @@ -104,6 +109,7 @@ def __init__(
self.variable_values = variable_values
self.context = context
self.path = path
self.extensions = extensions


__all__ = [
Expand Down
28 changes: 25 additions & 3 deletions graphql/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,17 @@ def on_resolve(data):
if isinstance(data, Observable):
return data

if not exe_context.errors:
if exe_context.errors and exe_context.extensions:
return ExecutionResult(
data=data, errors=exe_context.errors, extensions=exe_context.extensions
)
elif exe_context.errors:
return ExecutionResult(data=data, errors=exe_context.errors)
elif exe_context.extensions:
return ExecutionResult(data=data, extensions=exe_context.extensions)
else:
return ExecutionResult(data=data)

return ExecutionResult(data=data, errors=exe_context.errors)

promise = (
Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve)
)
Expand Down Expand Up @@ -354,6 +360,7 @@ def resolve_field(
variable_values=exe_context.variable_values,
context=context,
path=field_path,
extensions=exe_context.extensions,
)

executor = exe_context.executor
Expand Down Expand Up @@ -408,6 +415,7 @@ def subscribe_field(
variable_values=exe_context.variable_values,
context=context,
path=path,
extensions=exe_context.extensions,
)

executor = exe_context.executor
Expand Down Expand Up @@ -531,6 +539,20 @@ def complete_value(
),
)

# If result is ExecutionResult, update exe_context and complete for data field
if isinstance(result, ExecutionResult):
data = getattr(result, "data", None)
extensions = getattr(result, "extensions", None)
errors = getattr(result, "errors", None)

if extensions:
exe_context.update_extensions(extensions)
if errors:
for error in errors:
exe_context.report_error(error)

return complete_value(exe_context, return_type, field_asts, info, path, data)

# print return_type, type(result)
if isinstance(result, Exception):
raise GraphQLLocatedError(field_asts, original_error=result, path=path)
Expand Down
52 changes: 52 additions & 0 deletions graphql/execution/tests/test_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
GraphQLSchema,
GraphQLString,
)
from graphql.execution import ExecutionResult
from promise import Promise

# from graphql.execution.base import ResolveInfo
Expand Down Expand Up @@ -112,6 +113,40 @@ def resolver(source, info, **args):
]


def test_handles_resolved_extensions_with_data():
# type: () -> None
def resolver(source, info, **args):
# type: (Optional[str], ResolveInfo, **Any) -> ExecutionResult
extensions = info.extensions or {}
extensions["test_extensions"] = extensions.get("test_extensions", {})
extensions["test_extensions"].update({"foo": "bar"})
return ExecutionResult(data="foobar", extensions=extensions)

schema = _test_schema(GraphQLField(GraphQLString, resolver=resolver))

result = graphql(schema, "{ test }", None)
assert not result.errors
assert result.data == {"test": "foobar"}
assert result.extensions == {"test_extensions": {"foo": "bar"}}


def test_handles_resolved_extensions_with_errors():
# type: () -> None
def resolver(source, info, **args):
# type: (Optional[str], ResolveInfo, **Any) -> ExecutionResult
extensions = info.extensions or {}
extensions["errors"] = extensions.get("errors", {})
extensions["errors"].update({"test": {"foo": "bar"}})
return ExecutionResult(errors=[Exception()], extensions=extensions)

schema = _test_schema(GraphQLField(GraphQLString, resolver=resolver))

result = graphql(schema, "{ test }", None)
assert len(result.errors) == 1
assert result.data == {"test": None}
assert result.extensions == {"errors": {"test": {"foo": "bar"}}}


def test_handles_resolved_promises():
# type: () -> None
def resolver(source, info, **args):
Expand All @@ -125,6 +160,23 @@ def resolver(source, info, **args):
assert result.data == {"test": "foo"}


def test_handles_resolved_promises_extensions():
# type: () -> None
def resolver(source, info, **args):
# type: (Optional[Any], ResolveInfo, **Any) -> Promise
extensions = info.extensions or {}
extensions["test_extensions"] = extensions.get("test_extensions", {})
extensions["test_extensions"].update({"foo": "bar"})
return Promise.resolve(ExecutionResult(data="foobar", extensions=extensions))

schema = _test_schema(GraphQLField(GraphQLString, resolver=resolver))

result = graphql(schema, "{ test }", None)
assert not result.errors
assert result.data == {"test": "foobar"}
assert result.extensions == {"test_extensions": {"foo": "bar"}}


def test_handles_resolved_custom_promises():
# type: () -> None
def resolver(source, info, **args):
Expand Down
10 changes: 10 additions & 0 deletions graphql/execution/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from copy import deepcopy
import logging
from traceback import format_exception

Expand Down Expand Up @@ -54,6 +55,7 @@ class ExecutionContext(object):
"middleware",
"allow_subscriptions",
"_subfields_cache",
"extensions",
)

def __init__(
Expand All @@ -67,6 +69,7 @@ def __init__(
executor, # type: Any
middleware, # type: Optional[Any]
allow_subscriptions, # type: bool
extensions=None, # type: Dict
):
# type: (...) -> None
"""Constructs a ExecutionContext object from the arguments passed
Expand Down Expand Up @@ -126,6 +129,7 @@ def __init__(
self.middleware = middleware
self.allow_subscriptions = allow_subscriptions
self._subfields_cache = {} # type: Dict[Tuple[GraphQLObjectType, Tuple[Field, ...]], DefaultOrderedDict]
self.extensions = extensions

def get_field_resolver(self, field_resolver):
# type: (Callable) -> Callable
Expand All @@ -151,6 +155,12 @@ def report_error(self, error, traceback=None):
logger.error("".join(exception))
self.errors.append(error)

def update_extensions(self, extensions):
# type: (Dict[str, Any]) -> None
if extensions:
self.extensions = self.extensions or {}
self.extensions.update(extensions)

def get_sub_fields(self, return_type, field_asts):
# type: (GraphQLObjectType, List[Field]) -> DefaultOrderedDict
k = return_type, tuple(field_asts)
Expand Down