From 573fa724616912424808f846b4fe5f4bea8ddc4a Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 14 Jan 2025 18:41:35 +0000 Subject: [PATCH] WIP defer support --- RELEASE.md | 3 + poetry.lock | 16 +-- pyproject.toml | 3 +- strawberry/http/async_base_view.py | 101 +++++++++++++++- strawberry/schema/execute.py | 49 ++++---- strawberry/schema/schema.py | 12 +- strawberry/static/graphiql.html | 7 +- tests/http/incremental/__init__.py | 0 tests/http/incremental/conftest.py | 44 +++++++ tests/http/incremental/test_defer.py | 38 ++++++ .../test_multipart_subscription.py | 62 ++++++++++ tests/http/test_multipart_subscription.py | 113 ------------------ 12 files changed, 291 insertions(+), 157 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/http/incremental/__init__.py create mode 100644 tests/http/incremental/conftest.py create mode 100644 tests/http/incremental/test_defer.py create mode 100644 tests/http/incremental/test_multipart_subscription.py delete mode 100644 tests/http/test_multipart_subscription.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..544e3920ea --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +@defer 👀 diff --git a/poetry.lock b/poetry.lock index 22cd41a298..df0c27ff66 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiofiles" @@ -1355,17 +1355,17 @@ files = [ [[package]] name = "graphql-core" -version = "3.2.5" -description = "GraphQL implementation for Python, a port of GraphQL.js, the JavaScript reference implementation for GraphQL." +version = "3.3.0a6" +description = "GraphQL-core is a Python port of GraphQL.js,the JavaScript reference implementation for GraphQL." optional = false -python-versions = "<4,>=3.6" +python-versions = "<4.0,>=3.7" files = [ - {file = "graphql_core-3.2.5-py3-none-any.whl", hash = "sha256:2f150d5096448aa4f8ab26268567bbfeef823769893b39c1a2e1409590939c8a"}, - {file = "graphql_core-3.2.5.tar.gz", hash = "sha256:e671b90ed653c808715645e3998b7ab67d382d55467b7e2978549111bbabf8d5"}, + {file = "graphql_core-3.3.0a6-py3-none-any.whl", hash = "sha256:ad99089e04ad7450956cb5f834986b5d9625ff5d90cee754bfab56da93a062b8"}, + {file = "graphql_core-3.3.0a6.tar.gz", hash = "sha256:3456712b3e6fd45c0d48bf2c1d87e7a80680da987e29f64563faab0886dab380"}, ] [package.dependencies] -typing-extensions = {version = ">=4,<5", markers = "python_version < \"3.10\""} +typing-extensions = {version = ">=4.12,<5.0", markers = "python_version >= \"3.8\" and python_version < \"3.10\""} [[package]] name = "h11" @@ -4774,4 +4774,4 @@ sanic = ["sanic"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "6ccb728661c37d68d045aaae1651fd3f55b6905aa7d9d81711b70da488f7d05f" +content-hash = "6b3a16c7185b6c7290b727311f16f6ec254c810e39e7db656958f0967ae6dcc0" diff --git a/pyproject.toml b/pyproject.toml index 1a177b222f..6ef14b67fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.dependencies] python = "^3.9" -graphql-core = ">=3.2.0,<3.4.0" +graphql-core = ">=3.2.0" typing-extensions = ">=4.5.0" python-dateutil = "^2.7.0" starlette = {version = ">=0.18.0", optional = true} @@ -102,6 +102,7 @@ types-deprecated = "^1.2.15.20241117" types-six = "^1.17.0.20241205" types-pyyaml = "^6.0.12.20240917" mypy = "^1.13.0" +graphql-core = "3.3.0a6" [tool.poetry.group.integrations] optional = true diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index b73eb55181..30fca81383 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -17,6 +17,18 @@ from graphql import GraphQLError +# TODO: only import this if exists +from graphql.execution.execute import ( + ExperimentalIncrementalExecutionResults, + InitialIncrementalExecutionResult, +) +from graphql.execution.incremental_publisher import ( + IncrementalDeferResult, + IncrementalResult, + IncrementalStreamResult, + SubsequentIncrementalExecutionResult, +) + from strawberry.exceptions import MissingQueryError from strawberry.file_uploads.utils import replace_placeholders_with_files from strawberry.http import ( @@ -337,6 +349,29 @@ async def run( except MissingQueryError as e: raise HTTPException(400, "No GraphQL query found in the request") from e + if isinstance(result, ExperimentalIncrementalExecutionResults): + + async def stream(): + yield "---" + response = await self.process_result(request, result.initial_result) + yield self.encode_multipart_data(response, "-") + + async for value in result.subsequent_results: + response = await self.process_subsequent_result(request, value) + yield self.encode_multipart_data(response, "-") + + yield "--\r\n" + + return await self.create_streaming_response( + request, + stream, + sub_response, + headers={ + "Transfer-Encoding": "chunked", + "Content-Type": 'multipart/mixed; boundary="-"', + }, + ) + if isinstance(result, SubscriptionExecutionResult): stream = self._get_stream(request, result) @@ -360,12 +395,15 @@ async def run( ) def encode_multipart_data(self, data: Any, separator: str) -> str: + encoded_data = self.encode_json(data) + return "".join( [ - f"\r\n--{separator}\r\n", - "Content-Type: application/json\r\n\r\n", - self.encode_json(data), - "\n", + "\r\n", + "Content-Type: application/json; charset=utf-8\r\n", + "\r\n", + encoded_data, + f"\r\n--{separator}", ] ) @@ -475,9 +513,62 @@ async def parse_http_body( protocol=protocol, ) + def process_incremental_result( + self, request: Request, result: IncrementalResult + ) -> GraphQLHTTPResponse: + if isinstance(result, IncrementalDeferResult): + return { + "data": result.data, + "errors": result.errors, + "path": result.path, + "label": result.label, + "extensions": result.extensions, + } + if isinstance(result, IncrementalStreamResult): + return { + "items": result.items, + "errors": result.errors, + "path": result.path, + "label": result.label, + "extensions": result.extensions, + } + + raise ValueError(f"Unsupported incremental result type: {type(result)}") + + async def process_subsequent_result( + self, + request: Request, + result: SubsequentIncrementalExecutionResult, + # TODO: use proper return type + ) -> GraphQLHTTPResponse: + data = { + "incremental": [ + await self.process_result(request, value) + for value in result.incremental + ], + "hasNext": result.has_next, + "extensions": result.extensions, + } + + return data + async def process_result( - self, request: Request, result: ExecutionResult + self, + request: Request, + result: Union[ExecutionResult, InitialIncrementalExecutionResult], ) -> GraphQLHTTPResponse: + if isinstance(result, InitialIncrementalExecutionResult): + return { + "data": result.data, + "incremental": [ + self.process_incremental_result(request, value) + for value in result.incremental + ] + if result.incremental + else [], + "hasNext": result.has_next, + "extensions": result.extensions, + } return process_result(result) async def on_ws_connect( diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index b0e1ebf45d..ac0eca1046 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -14,7 +14,7 @@ from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, parse -from graphql import execute as original_execute +from graphql.execution import experimental_execute_incrementally from graphql.validation import validate from strawberry.exceptions import MissingQueryError @@ -121,16 +121,17 @@ async def _handle_execution_result( extensions_runner: SchemaExtensionsRunner, process_errors: ProcessErrors | None, ) -> ExecutionResult: - # Set errors on the context so that it's easier - # to access in extensions - if result.errors: - context.errors = result.errors - if process_errors: - process_errors(result.errors, context) - if isinstance(result, GraphQLExecutionResult): - result = ExecutionResult(data=result.data, errors=result.errors) - result.extensions = await extensions_runner.get_extensions_results(context) - context.result = result # type: ignore # mypy failed to deduce correct type. + # TODO: deal with this later + # # Set errors on the context so that it's easier + # # to access in extensions + # if result.errors: + # context.errors = result.errors + # if process_errors: + # process_errors(result.errors, context) + # if isinstance(result, GraphQLExecutionResult): + # result = ExecutionResult(data=result.data, errors=result.errors) + # result.extensions = await extensions_runner.get_extensions_results(context) + # context.result = result # type: ignore # mypy failed to deduce correct type. return result @@ -164,7 +165,7 @@ async def execute( async with extensions_runner.executing(): if not execution_context.result: result = await await_maybe( - original_execute( + experimental_execute_incrementally( schema, execution_context.graphql_document, root_value=execution_context.root_value, @@ -178,16 +179,18 @@ async def execute( execution_context.result = result else: result = execution_context.result - # Also set errors on the execution_context so that it's easier - # to access in extensions - if result.errors: - execution_context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, execution_context) + # TODO: deal with this later + # # Also set errors on the execution_context so that it's easier + # # to access in extensions + # breakpoint() + # if result.errors: + # execution_context.errors = result.errors + + # # Run the `Schema.process_errors` function here before + # # extensions have a chance to modify them (see the MaskErrors + # # extension). That way we can log the original errors but + # # only return a sanitised version to the client. + # process_errors(result.errors, execution_context) except (MissingQueryError, InvalidOperationTypeError): raise @@ -252,7 +255,7 @@ def execute_sync( with extensions_runner.executing(): if not execution_context.result: - result = original_execute( + result = experimental_execute_incrementally( schema, execution_context.graphql_document, root_value=execution_context.root_value, diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index a7de78c95a..6ee9a1b21c 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -20,7 +20,11 @@ validate_schema, ) from graphql.execution.middleware import MiddlewareManager -from graphql.type.directives import specified_directives +from graphql.type.directives import ( + GraphQLDeferDirective, + GraphQLStreamDirective, + specified_directives, +) from strawberry import relay from strawberry.annotation import StrawberryAnnotation @@ -194,7 +198,11 @@ class Query: query=query_type, mutation=mutation_type, subscription=subscription_type if subscription else None, - directives=specified_directives + tuple(graphql_directives), + directives=( + specified_directives + + tuple(graphql_directives) + + (GraphQLDeferDirective, GraphQLStreamDirective) + ), types=graphql_types, extensions={ GraphQLCoreConverter.DEFINITION_BACKREF: self, diff --git a/strawberry/static/graphiql.html b/strawberry/static/graphiql.html index b66082a97f..9892cf9a36 100644 --- a/strawberry/static/graphiql.html +++ b/strawberry/static/graphiql.html @@ -61,8 +61,7 @@ Loading...