diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..08fcdd26a6 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: patch + +This release fixes an issue that prevented extensions to receive the result from +the execution context when executing operations in async. diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 4090b8e3b6..5262535062 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -123,18 +123,14 @@ async def _handle_execution_result( context: ExecutionContext, result: Union[GraphQLExecutionResult, ExecutionResult], extensions_runner: SchemaExtensionsRunner, - process_errors: ProcessErrors, + process_errors: ProcessErrors | None, ) -> ExecutionResult: # Set errors on the context so that it's easier # to access in extensions if result.errors: context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, context) + if process_errors: + process_errors(result.errors, context) if isinstance(result, GraphQLExecutionResult): result = ExecutionResult(data=result.data, errors=result.errors) result.extensions = await extensions_runner.get_extensions_results(context) @@ -171,7 +167,7 @@ async def execute( assert execution_context.graphql_document async with extensions_runner.executing(): if not execution_context.result: - res = await await_maybe( + result = await await_maybe( original_execute( schema, execution_context.graphql_document, @@ -183,9 +179,20 @@ async def execute( execution_context_class=execution_context_class, ) ) - + execution_context.result = result else: - res = execution_context.result + result = execution_context.result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + except (MissingQueryError, InvalidOperationTypeError) as e: raise e except Exception as exc: @@ -195,10 +202,9 @@ async def execute( extensions_runner, process_errors, ) - # return results after all the operation completed. return await _handle_execution_result( - execution_context, res, extensions_runner, process_errors + execution_context, result, extensions_runner, None ) diff --git a/tests/schema/extensions/test_mask_errors.py b/tests/schema/extensions/test_mask_errors.py index 6d3c406c40..58ef32c28e 100644 --- a/tests/schema/extensions/test_mask_errors.py +++ b/tests/schema/extensions/test_mask_errors.py @@ -29,6 +29,29 @@ def hidden_error(self) -> str: ] +async def test_mask_all_errors_async(): + @strawberry.type + class Query: + @strawberry.field + def hidden_error(self) -> str: + raise KeyError("This error is not visible") + + schema = strawberry.Schema(query=Query, extensions=[MaskErrors()]) + + query = "query { hiddenError }" + + result = await schema.execute(query) + assert result.errors is not None + formatted_errors = [err.formatted for err in result.errors] + assert formatted_errors == [ + { + "locations": [{"column": 9, "line": 1}], + "message": "Unexpected error.", + "path": ["hiddenError"], + } + ] + + def test_mask_some_errors(): class VisibleError(Exception): pass