From e97afee7c7cbb37f87685408bb5a3bef84197c3f Mon Sep 17 00:00:00 2001 From: omarzouk Date: Fri, 19 Apr 2024 19:35:48 +0200 Subject: [PATCH 01/15] first attempt, introducing new param --- strawberry/http/__init__.py | 2 ++ strawberry/http/async_base_view.py | 1 + strawberry/schema/base.py | 1 + strawberry/schema/schema.py | 4 ++++ strawberry/types/execution.py | 1 + 5 files changed, 9 insertions(+) diff --git a/strawberry/http/__init__.py b/strawberry/http/__init__.py index dc86e7c9f8..3c503d9f98 100644 --- a/strawberry/http/__init__.py +++ b/strawberry/http/__init__.py @@ -33,6 +33,7 @@ class GraphQLRequestData: query: Optional[str] variables: Optional[Dict[str, Any]] operation_name: Optional[str] + extensions: Optional[Dict[str, Any]] def parse_query_params(params: Dict[str, str]) -> Dict[str, Any]: @@ -47,4 +48,5 @@ def parse_request_data(data: Mapping[str, Any]) -> GraphQLRequestData: query=data.get("query"), variables=data.get("variables"), operation_name=data.get("operationName"), + extensions=data.get("extensions"), ) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 4e800238a0..ebbe08dccd 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -121,6 +121,7 @@ async def execute_operation( context_value=context, operation_name=request_data.operation_name, allowed_operation_types=allowed_operation_types, + protocol_extensions=request_data.extensions ) async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> Dict[str, str]: diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index a1c286c6d0..39f0ca81bc 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -39,6 +39,7 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, + protocol_extensions: Optional[Dict[str, Any]] = None, ) -> ExecutionResult: raise NotImplementedError diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index b43963d9b5..e277a193c0 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -246,6 +246,7 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, + protocol_extensions: Optional[Dict[str, Any]] = None, ) -> ExecutionResult: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES @@ -258,6 +259,7 @@ async def execute( root_value=root_value, variables=variable_values, provided_operation_name=operation_name, + protocol_extensions=protocol_extensions ) result = await execute( @@ -279,6 +281,7 @@ def execute_sync( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, + protocol_extensions: Optional[Dict[str, Any]] = None, ) -> ExecutionResult: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES @@ -290,6 +293,7 @@ def execute_sync( root_value=root_value, variables=variable_values, provided_operation_name=operation_name, + protocol_extensions=protocol_extensions ) result = execute_sync( diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 9dc7ff7ef3..54adafee2e 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -35,6 +35,7 @@ class ExecutionContext: schema: Schema context: Any = None variables: Optional[Dict[str, Any]] = None + protocol_extensions: Optional[Dict[str, Any]] = None parse_options: ParseOptions = dataclasses.field( default_factory=lambda: ParseOptions() ) From 41fad1608658fea58b9d41a9e0ea03e2a870a55b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Apr 2024 17:38:30 +0000 Subject: [PATCH 02/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/http/async_base_view.py | 2 +- strawberry/schema/schema.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index ebbe08dccd..b9ed3c49e0 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -121,7 +121,7 @@ async def execute_operation( context_value=context, operation_name=request_data.operation_name, allowed_operation_types=allowed_operation_types, - protocol_extensions=request_data.extensions + protocol_extensions=request_data.extensions, ) async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> Dict[str, str]: diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index e277a193c0..b77da24bec 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -259,7 +259,7 @@ async def execute( root_value=root_value, variables=variable_values, provided_operation_name=operation_name, - protocol_extensions=protocol_extensions + protocol_extensions=protocol_extensions, ) result = await execute( @@ -293,7 +293,7 @@ def execute_sync( root_value=root_value, variables=variable_values, provided_operation_name=operation_name, - protocol_extensions=protocol_extensions + protocol_extensions=protocol_extensions, ) result = execute_sync( From cbd138bcc745fa69074f80b5291452bf2f2a123b Mon Sep 17 00:00:00 2001 From: omarzouk Date: Fri, 19 Apr 2024 19:52:00 +0200 Subject: [PATCH 03/15] missing params --- strawberry/http/async_base_view.py | 1 + strawberry/http/sync_base_view.py | 1 + 2 files changed, 2 insertions(+) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index b9ed3c49e0..064ea8e48f 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -220,6 +220,7 @@ async def parse_http_body( query=data.get("query"), variables=data.get("variables"), operation_name=data.get("operationName"), + extensions=data.get("extensions"), ) async def process_result( diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index d0fbf03576..e3f359b023 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -159,6 +159,7 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData query=data.get("query"), variables=data.get("variables"), operation_name=data.get("operationName"), + extensions=data.get("extensions"), ) def _handle_errors( From 69b35c3c0b06d0a2805d3d631d6db03ba4605285 Mon Sep 17 00:00:00 2001 From: omarzouk Date: Mon, 24 Jun 2024 12:55:28 +0200 Subject: [PATCH 04/15] revert trying to pass to execute --- strawberry/http/async_base_view.py | 1 - strawberry/schema/base.py | 1 - strawberry/schema/schema.py | 4 ---- strawberry/types/execution.py | 1 - 4 files changed, 7 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 064ea8e48f..af32a1c12a 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -121,7 +121,6 @@ async def execute_operation( context_value=context, operation_name=request_data.operation_name, allowed_operation_types=allowed_operation_types, - protocol_extensions=request_data.extensions, ) async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> Dict[str, str]: diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index 39f0ca81bc..a1c286c6d0 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -39,7 +39,6 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - protocol_extensions: Optional[Dict[str, Any]] = None, ) -> ExecutionResult: raise NotImplementedError diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index b77da24bec..b43963d9b5 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -246,7 +246,6 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - protocol_extensions: Optional[Dict[str, Any]] = None, ) -> ExecutionResult: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES @@ -259,7 +258,6 @@ async def execute( root_value=root_value, variables=variable_values, provided_operation_name=operation_name, - protocol_extensions=protocol_extensions, ) result = await execute( @@ -281,7 +279,6 @@ def execute_sync( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - protocol_extensions: Optional[Dict[str, Any]] = None, ) -> ExecutionResult: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES @@ -293,7 +290,6 @@ def execute_sync( root_value=root_value, variables=variable_values, provided_operation_name=operation_name, - protocol_extensions=protocol_extensions, ) result = execute_sync( diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 54adafee2e..9dc7ff7ef3 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -35,7 +35,6 @@ class ExecutionContext: schema: Schema context: Any = None variables: Optional[Dict[str, Any]] = None - protocol_extensions: Optional[Dict[str, Any]] = None parse_options: ParseOptions = dataclasses.field( default_factory=lambda: ParseOptions() ) From 2841b698ecd4f711af8b0540851f48c031f723dc Mon Sep 17 00:00:00 2001 From: omarzouk Date: Mon, 24 Jun 2024 13:05:56 +0200 Subject: [PATCH 05/15] implement Option 1 --- strawberry/http/async_base_view.py | 6 +++++- strawberry/http/sync_base_view.py | 6 +++++- strawberry/types/context_wrapper.py | 8 ++++++++ strawberry/types/info.py | 9 +++++++++ 4 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 strawberry/types/context_wrapper.py diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index af32a1c12a..e71038b506 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -26,6 +26,7 @@ from .exceptions import HTTPException from .types import FormData, HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse +from ..types.context_wrapper import ContextWrapper class AsyncHTTPRequestAdapter(abc.ABC): @@ -114,11 +115,14 @@ async def execute_operation( assert self.schema + context_wrapper = ContextWrapper(context=context, + extensions=request_data.extensions) + return await self.schema.execute( request_data.query, root_value=root_value, variable_values=request_data.variables, - context_value=context, + context_value=context_wrapper, operation_name=request_data.operation_name, allowed_operation_types=allowed_operation_types, ) diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index e3f359b023..c4eecb9d6f 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -27,6 +27,7 @@ from .exceptions import HTTPException from .types import HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse +from ..types.context_wrapper import ContextWrapper class SyncHTTPRequestAdapter(abc.ABC): @@ -125,11 +126,14 @@ def execute_operation( assert self.schema + context_wrapper = ContextWrapper(context=context, + extensions=request_data.extensions) + return self.schema.execute_sync( request_data.query, root_value=root_value, variable_values=request_data.variables, - context_value=context, + context_value=context_wrapper, operation_name=request_data.operation_name, allowed_operation_types=allowed_operation_types, ) diff --git a/strawberry/types/context_wrapper.py b/strawberry/types/context_wrapper.py new file mode 100644 index 0000000000..ad87c87f1b --- /dev/null +++ b/strawberry/types/context_wrapper.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +@dataclass +class ContextWrapper: + context: Optional[Any] + extensions: Optional[Dict[str, Any]] diff --git a/strawberry/types/info.py b/strawberry/types/info.py index 879652508f..206f459b90 100644 --- a/strawberry/types/info.py +++ b/strawberry/types/info.py @@ -16,6 +16,7 @@ ) from typing_extensions import TypeVar +from .context_wrapper import ContextWrapper from .nodes import convert_selections if TYPE_CHECKING: @@ -79,8 +80,16 @@ def selected_fields(self) -> List[Selection]: @property def context(self) -> ContextType: + if type(self._raw_info.context) is ContextWrapper: + return self._raw_info.context.context return self._raw_info.context + @property + def input_extensions(self) -> Dict[str, Any]: + if type(self._raw_info.context) is ContextWrapper: + return self._raw_info.context.extensions + return {} + @property def root_value(self) -> RootValueType: return self._raw_info.root_value From 419f2f914627470fbd043be3b3ee925939dc0ce2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:06:11 +0000 Subject: [PATCH 06/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/http/async_base_view.py | 7 ++++--- strawberry/http/sync_base_view.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index e71038b506..f6bf2f3367 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -22,11 +22,11 @@ from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType +from ..types.context_wrapper import ContextWrapper from .base import BaseView from .exceptions import HTTPException from .types import FormData, HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse -from ..types.context_wrapper import ContextWrapper class AsyncHTTPRequestAdapter(abc.ABC): @@ -115,8 +115,9 @@ async def execute_operation( assert self.schema - context_wrapper = ContextWrapper(context=context, - extensions=request_data.extensions) + context_wrapper = ContextWrapper( + context=context, extensions=request_data.extensions + ) return await self.schema.execute( request_data.query, diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index c4eecb9d6f..2b2c72e4f8 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -23,11 +23,11 @@ from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType +from ..types.context_wrapper import ContextWrapper from .base import BaseView from .exceptions import HTTPException from .types import HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse -from ..types.context_wrapper import ContextWrapper class SyncHTTPRequestAdapter(abc.ABC): @@ -126,8 +126,9 @@ def execute_operation( assert self.schema - context_wrapper = ContextWrapper(context=context, - extensions=request_data.extensions) + context_wrapper = ContextWrapper( + context=context, extensions=request_data.extensions + ) return self.schema.execute_sync( request_data.query, From 393b97697aa57752f3e62ceed25cd9edd82edcff Mon Sep 17 00:00:00 2001 From: omarzouk Date: Mon, 24 Jun 2024 15:53:20 +0200 Subject: [PATCH 07/15] fix: missing parsing of extensions in params for GET requests --- strawberry/http/base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/strawberry/http/base.py b/strawberry/http/base.py index ffdc6f5446..d97445cf49 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -67,6 +67,15 @@ def parse_query_params( if variables: params["variables"] = self.parse_json(variables) + if "extensions" in params: + extensions = params["extensions"] + + if isinstance(extensions, list): + extensions = extensions[0] + + if extensions: + params["extensions"] = self.parse_json(extensions) + return params @property From 3eafc92d1e12eecca58b55076b405a18b454372e Mon Sep 17 00:00:00 2001 From: omarzouk Date: Mon, 24 Jun 2024 15:54:36 +0200 Subject: [PATCH 08/15] chore: add new test and adjust http test code to pass the parameter --- tests/http/clients/aiohttp.py | 4 +++- tests/http/clients/asgi.py | 4 +++- tests/http/clients/base.py | 12 +++++++++++- tests/http/clients/chalice.py | 4 +++- tests/http/clients/channels.py | 13 ++++++++++--- tests/http/clients/django.py | 4 +++- tests/http/clients/fastapi.py | 4 +++- tests/http/clients/flask.py | 4 +++- tests/http/clients/litestar.py | 4 +++- tests/http/clients/quart.py | 4 +++- tests/http/clients/sanic.py | 4 +++- tests/http/clients/starlite.py | 4 +++- tests/http/test_query.py | 15 +++++++++++++++ tests/views/schema.py | 4 ++++ 14 files changed, 70 insertions(+), 14 deletions(-) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index cd552e877c..44d9bc99c3 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -106,11 +106,13 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: async with TestClient(TestServer(self.app)) as client: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) if body and files: diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 72d9e95aa6..745f53beeb 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -97,10 +97,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) if method == "get": diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 38447799d0..43a94c0402 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -56,6 +56,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: ... @@ -94,9 +95,11 @@ async def query( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None ) -> Response: return await self._graphql_request( - method, query=query, headers=headers, variables=variables, files=files + method, query=query, headers=headers, variables=variables, files=files, + extensions=extensions ) def _get_headers( @@ -122,6 +125,7 @@ def _build_body( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, method: Literal["get", "post"] = "post", + extensions: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, object]]: if query is None: assert files is None @@ -134,6 +138,9 @@ def _build_body( if variables: body["variables"] = variables + if extensions: + body["extensions"] = extensions + if files: assert variables is not None @@ -147,6 +154,9 @@ def _build_body( if method == "get" and variables: body["variables"] = json.dumps(variables) + if method == "get" and extensions: + body["extensions"] = json.dumps(extensions) + return body @staticmethod diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index eddb7d8ada..7a30ca7888 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -74,10 +74,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) data: Union[Dict[str, object], str, None] = None diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index ee31c8e88b..6edf33112e 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -32,11 +32,14 @@ def generate_get_path( - path, query: str, variables: Optional[Dict[str, Any]] = None + path, query: str, variables: Optional[Dict[str, Any]] = None, + extensions: Optional[Dict[str, Any]] = None ) -> str: body: Dict[str, Any] = {"query": query} if variables is not None: body["variables"] = json_module.dumps(variables) + if extensions is not None: + body["extensions"] = json_module.dumps(extensions) parts = [f"{k}={v}" for k, v in body.items()] return f"{path}?{'&'.join(parts)}" @@ -167,10 +170,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) headers = self._get_headers(method=method, headers=headers, files=files) @@ -185,7 +190,9 @@ async def _graphql_request( endpoint_url = "/graphql" else: body = b"" - endpoint_url = generate_get_path("/graphql", query, variables) + endpoint_url = generate_get_path( + "/graphql", query, variables, extensions + ) return await self.request( url=endpoint_url, method=method, body=body, headers=headers diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 75efa2825c..69891bb41a 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -101,13 +101,15 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: headers = self._get_headers(method=method, headers=headers, files=files) additional_arguments = {**kwargs, **headers} body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) data: Union[Dict[str, object], str, None] = None diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index f271509f40..469f5802d1 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -117,10 +117,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) if body: diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index abc0e3cec4..c60e0b9ec5 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -86,10 +86,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) data: Union[Dict[str, object], str, None] = None diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index ccf9999f7f..7ea8bab100 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -98,10 +98,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: if body := self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ): if method == "get": kwargs["params"] = body diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 60bc14b8c2..48b9b709ca 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -79,10 +79,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) url = "/graphql" diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 449aa316e8..091572b897 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -76,10 +76,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) if body: diff --git a/tests/http/clients/starlite.py b/tests/http/clients/starlite.py index 9af3a8206e..62e0c7cf7b 100644 --- a/tests/http/clients/starlite.py +++ b/tests/http/clients/starlite.py @@ -98,10 +98,12 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, + extensions: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method + query=query, variables=variables, files=files, method=method, + extensions=extensions ) if body: diff --git a/tests/http/test_query.py b/tests/http/test_query.py index f60a122d8c..2a2c450983 100644 --- a/tests/http/test_query.py +++ b/tests/http/test_query.py @@ -191,6 +191,21 @@ async def test_query_context(method: Literal["get", "post"], http_client: HttpCl assert data["valueFromContext"] == "a value from context" +@pytest.mark.parametrize("method", ["get", "post"]) +async def test_query_extensions( + method: Literal["get", "post"], http_client: HttpClient +): + response = await http_client.query( + method=method, + query='{ valueFromExtensions(key:"test") }', + extensions={"test": "hello"} + ) + data = response.json["data"] + + assert response.status_code == 200 + assert data["valueFromExtensions"] == "hello" + + @pytest.mark.parametrize("method", ["get", "post"]) async def test_returning_status_code( method: Literal["get", "post"], http_client: HttpClient diff --git a/tests/views/schema.py b/tests/views/schema.py index 14a25b0c0d..6e9c7bf031 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -108,6 +108,10 @@ def root_name(self) -> str: def value_from_context(self, info: strawberry.Info) -> str: return info.context["custom_value"] + @strawberry.field + def value_from_extensions(self, key: str, info: strawberry.Info) -> str: + return info.input_extensions[key] + @strawberry.field def returns_401(self, info: strawberry.Info) -> str: response = info.context["response"] From ae3456b81a3fd8d4db662b4a2dcf30c5eafdfde7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:54:52 +0000 Subject: [PATCH 09/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/http/clients/aiohttp.py | 7 +++++-- tests/http/clients/asgi.py | 7 +++++-- tests/http/clients/base.py | 10 +++++++--- tests/http/clients/chalice.py | 7 +++++-- tests/http/clients/channels.py | 17 ++++++++++------- tests/http/clients/django.py | 7 +++++-- tests/http/clients/fastapi.py | 7 +++++-- tests/http/clients/flask.py | 7 +++++-- tests/http/clients/litestar.py | 7 +++++-- tests/http/clients/quart.py | 7 +++++-- tests/http/clients/sanic.py | 7 +++++-- tests/http/clients/starlite.py | 7 +++++-- tests/http/test_query.py | 2 +- 13 files changed, 68 insertions(+), 31 deletions(-) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 44d9bc99c3..0bfc3ef68b 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -111,8 +111,11 @@ async def _graphql_request( ) -> Response: async with TestClient(TestServer(self.app)) as client: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) if body and files: diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 745f53beeb..87a4296775 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -101,8 +101,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) if method == "get": diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 43a94c0402..7bcae4f1c4 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -95,11 +95,15 @@ async def query( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - extensions: Optional[Dict[str, Any]] = None + extensions: Optional[Dict[str, Any]] = None, ) -> Response: return await self._graphql_request( - method, query=query, headers=headers, variables=variables, files=files, - extensions=extensions + method, + query=query, + headers=headers, + variables=variables, + files=files, + extensions=extensions, ) def _get_headers( diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index 7a30ca7888..0fc6838958 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -78,8 +78,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) data: Union[Dict[str, object], str, None] = None diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 6edf33112e..3592375fa3 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -32,8 +32,10 @@ def generate_get_path( - path, query: str, variables: Optional[Dict[str, Any]] = None, - extensions: Optional[Dict[str, Any]] = None + path, + query: str, + variables: Optional[Dict[str, Any]] = None, + extensions: Optional[Dict[str, Any]] = None, ) -> str: body: Dict[str, Any] = {"query": query} if variables is not None: @@ -174,8 +176,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) headers = self._get_headers(method=method, headers=headers, files=files) @@ -190,9 +195,7 @@ async def _graphql_request( endpoint_url = "/graphql" else: body = b"" - endpoint_url = generate_get_path( - "/graphql", query, variables, extensions - ) + endpoint_url = generate_get_path("/graphql", query, variables, extensions) return await self.request( url=endpoint_url, method=method, body=body, headers=headers diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 69891bb41a..9e5a358493 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -108,8 +108,11 @@ async def _graphql_request( additional_arguments = {**kwargs, **headers} body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) data: Union[Dict[str, object], str, None] = None diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 469f5802d1..6f09baa4bf 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -121,8 +121,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) if body: diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index c60e0b9ec5..60c2d67801 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -90,8 +90,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) data: Union[Dict[str, object], str, None] = None diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 7ea8bab100..bfedb6a6e8 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -102,8 +102,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: if body := self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ): if method == "get": kwargs["params"] = body diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 48b9b709ca..4c65224f5c 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -83,8 +83,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) url = "/graphql" diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 091572b897..6038179e17 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -80,8 +80,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) if body: diff --git a/tests/http/clients/starlite.py b/tests/http/clients/starlite.py index 62e0c7cf7b..0bc2a152f1 100644 --- a/tests/http/clients/starlite.py +++ b/tests/http/clients/starlite.py @@ -102,8 +102,11 @@ async def _graphql_request( **kwargs: Any, ) -> Response: body = self._build_body( - query=query, variables=variables, files=files, method=method, - extensions=extensions + query=query, + variables=variables, + files=files, + method=method, + extensions=extensions, ) if body: diff --git a/tests/http/test_query.py b/tests/http/test_query.py index 2a2c450983..665e77464f 100644 --- a/tests/http/test_query.py +++ b/tests/http/test_query.py @@ -198,7 +198,7 @@ async def test_query_extensions( response = await http_client.query( method=method, query='{ valueFromExtensions(key:"test") }', - extensions={"test": "hello"} + extensions={"test": "hello"}, ) data = response.json["data"] From ea1b0539ed0a448771373aca639e374bec766d40 Mon Sep 17 00:00:00 2001 From: Omar Marzouk Date: Sat, 6 Jul 2024 21:18:05 +0200 Subject: [PATCH 10/15] Change parameter order in test function Co-authored-by: Thiago Bellini Ribeiro --- tests/views/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/views/schema.py b/tests/views/schema.py index 6e9c7bf031..6d12dbe771 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -109,7 +109,7 @@ def value_from_context(self, info: strawberry.Info) -> str: return info.context["custom_value"] @strawberry.field - def value_from_extensions(self, key: str, info: strawberry.Info) -> str: + def value_from_extensions(self, info: strawberry.Info, key: str) -> str: return info.input_extensions[key] @strawberry.field From 351c48e6a742524783c92f9f48ca4c8df2962f6d Mon Sep 17 00:00:00 2001 From: Omar Marzouk Date: Sat, 6 Jul 2024 21:18:38 +0200 Subject: [PATCH 11/15] Use isinstance Co-authored-by: Thiago Bellini Ribeiro --- strawberry/types/info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/types/info.py b/strawberry/types/info.py index 206f459b90..d368657a6a 100644 --- a/strawberry/types/info.py +++ b/strawberry/types/info.py @@ -80,7 +80,7 @@ def selected_fields(self) -> List[Selection]: @property def context(self) -> ContextType: - if type(self._raw_info.context) is ContextWrapper: + if isinstance(self._raw_info.context, ContextWrapper): return self._raw_info.context.context return self._raw_info.context From 494d2b6bd7e6135dfe5679ba42fc6e8aa7509e3d Mon Sep 17 00:00:00 2001 From: Omar Marzouk Date: Sat, 6 Jul 2024 21:18:48 +0200 Subject: [PATCH 12/15] Use isinstance Co-authored-by: Thiago Bellini Ribeiro --- strawberry/types/info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/types/info.py b/strawberry/types/info.py index d368657a6a..d94c87f69d 100644 --- a/strawberry/types/info.py +++ b/strawberry/types/info.py @@ -86,7 +86,7 @@ def context(self) -> ContextType: @property def input_extensions(self) -> Dict[str, Any]: - if type(self._raw_info.context) is ContextWrapper: + if isinstance(self._raw_info.context, ContextWrapper): return self._raw_info.context.extensions return {} From 2b16e45157933a398c044f35f8e37ca69ed618c0 Mon Sep 17 00:00:00 2001 From: omarzouk Date: Sat, 6 Jul 2024 21:21:32 +0200 Subject: [PATCH 13/15] change to absolute imports --- strawberry/http/async_base_view.py | 2 +- strawberry/http/sync_base_view.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index f6bf2f3367..6706b1a5e0 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -22,7 +22,7 @@ from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType -from ..types.context_wrapper import ContextWrapper +from strawberry.types.context_wrapper import ContextWrapper from .base import BaseView from .exceptions import HTTPException from .types import FormData, HTTPMethod, QueryParams diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 2b2c72e4f8..47e688d51e 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -23,7 +23,7 @@ from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType -from ..types.context_wrapper import ContextWrapper +from strawberry.types.context_wrapper import ContextWrapper from .base import BaseView from .exceptions import HTTPException from .types import HTTPMethod, QueryParams From d22c2d95e3b43ad7c8fa72a873917eff72a96d32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Jul 2024 19:21:53 +0000 Subject: [PATCH 14/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/http/async_base_view.py | 2 +- strawberry/http/sync_base_view.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 6706b1a5e0..05e0a7605a 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -20,9 +20,9 @@ from strawberry.schema.base import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError from strawberry.types import ExecutionResult +from strawberry.types.context_wrapper import ContextWrapper from strawberry.types.graphql import OperationType -from strawberry.types.context_wrapper import ContextWrapper from .base import BaseView from .exceptions import HTTPException from .types import FormData, HTTPMethod, QueryParams diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 47e688d51e..31d866dc94 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -21,9 +21,9 @@ from strawberry.schema import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError from strawberry.types import ExecutionResult +from strawberry.types.context_wrapper import ContextWrapper from strawberry.types.graphql import OperationType -from strawberry.types.context_wrapper import ContextWrapper from .base import BaseView from .exceptions import HTTPException from .types import HTTPMethod, QueryParams From e3afec33e41618c862f5c33c5b116d8f1f0c6708 Mon Sep 17 00:00:00 2001 From: omarzouk Date: Tue, 16 Jul 2024 16:40:00 +0200 Subject: [PATCH 15/15] remove list check --- strawberry/http/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 3cf55cfbb0..1de1d08f53 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -62,9 +62,6 @@ def parse_query_params(self, params: QueryParams) -> Dict[str, Any]: if "extensions" in params: extensions = params["extensions"] - if isinstance(extensions, list): - extensions = extensions[0] - if extensions: params["extensions"] = self.parse_json(extensions)