Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for extensions to the HTTP protocol #3461

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions strawberry/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class GraphQLRequestData:
query: Optional[str]
variables: Optional[Dict[str, Any]]
operation_name: Optional[str]
extensions: Optional[Dict[str, Any]]
omarzouk marked this conversation as resolved.
Show resolved Hide resolved
omarzouk marked this conversation as resolved.
Show resolved Hide resolved
patrick91 marked this conversation as resolved.
Show resolved Hide resolved


def parse_query_params(params: Dict[str, str]) -> Dict[str, Any]:
Expand All @@ -47,6 +48,7 @@ def parse_request_data(data: Mapping[str, Any]) -> GraphQLRequestData:
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
extensions=data.get("extensions"),
)


Expand Down
8 changes: 7 additions & 1 deletion strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from strawberry.schema.base import BaseSchema
from strawberry.schema.exceptions import InvalidOperationTypeError
from strawberry.types import ExecutionResult
from strawberry.types.context_wrapper import ContextWrapper
from strawberry.types.graphql import OperationType

from .base import BaseView
Expand Down Expand Up @@ -102,11 +103,15 @@ async def execute_operation(

assert self.schema

context_wrapper = ContextWrapper(
context=context, extensions=request_data.extensions
)

return await self.schema.execute(
request_data.query,
root_value=root_value,
variable_values=request_data.variables,
context_value=context,
context_value=context_wrapper,
omarzouk marked this conversation as resolved.
Show resolved Hide resolved
operation_name=request_data.operation_name,
allowed_operation_types=allowed_operation_types,
)
Expand Down Expand Up @@ -205,6 +210,7 @@ async def parse_http_body(
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
extensions=data.get("extensions"),
)

async def process_result(
Expand Down
6 changes: 6 additions & 0 deletions strawberry/http/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def parse_query_params(self, params: QueryParams) -> Dict[str, Any]:
if variables:
params["variables"] = self.parse_json(variables)

if "extensions" in params:
extensions = params["extensions"]

if extensions:
params["extensions"] = self.parse_json(extensions)

return params

@property
Expand Down
8 changes: 7 additions & 1 deletion strawberry/http/sync_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from strawberry.schema import BaseSchema
from strawberry.schema.exceptions import InvalidOperationTypeError
from strawberry.types import ExecutionResult
from strawberry.types.context_wrapper import ContextWrapper
from strawberry.types.graphql import OperationType

from .base import BaseView
Expand Down Expand Up @@ -112,11 +113,15 @@ def execute_operation(

assert self.schema

context_wrapper = ContextWrapper(
omarzouk marked this conversation as resolved.
Show resolved Hide resolved
context=context, extensions=request_data.extensions
)
Comment on lines +116 to +118
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this context wrapper only to pass request data?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly. Since the context data comes from get_context and that can be any type of object the user decides to return, this felt like a safe way to attach more data

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try to review and merge this this week 😊 thanks for the patience!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you so much! 🙏


return self.schema.execute_sync(
request_data.query,
root_value=root_value,
variable_values=request_data.variables,
context_value=context,
context_value=context_wrapper,
operation_name=request_data.operation_name,
allowed_operation_types=allowed_operation_types,
)
Expand Down Expand Up @@ -146,6 +151,7 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
extensions=data.get("extensions"),
)

def _handle_errors(
Expand Down
8 changes: 8 additions & 0 deletions strawberry/types/context_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ContextWrapper:
omarzouk marked this conversation as resolved.
Show resolved Hide resolved
context: Optional[Any]
extensions: Optional[Dict[str, Any]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same suggestion as above

Suggested change
extensions: Optional[Dict[str, Any]]
extensions: Optional[Dict[str, Any]] = None

12 changes: 12 additions & 0 deletions strawberry/types/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from typing_extensions import TypeVar

from .context_wrapper import ContextWrapper
from .nodes import convert_selections

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,8 +116,19 @@
@property
def context(self) -> ContextType:
"""The context passed to the query execution."""
if isinstance(self._raw_info.context, ContextWrapper):
return self._raw_info.context.context

return self._raw_info.context

@property
def input_extensions(self) -> Dict[str, Any]:
"""The input extensions passed to the query execution."""
if isinstance(self._raw_info.context, ContextWrapper):
return self._raw_info.context.extensions

return {}

Check warning on line 130 in strawberry/types/info.py

View check run for this annotation

Codecov / codecov/patch

strawberry/types/info.py#L130

Added line #L130 was not covered by tests

@property
def root_value(self) -> RootValueType:
"""The root value passed to the query execution."""
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,16 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
async with TestClient(TestServer(self.app)) as client:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if body and files:
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if method == "get":
Expand Down
16 changes: 15 additions & 1 deletion tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response: ...

Expand Down Expand Up @@ -89,9 +90,15 @@ async def query(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
) -> Response:
return await self._graphql_request(
method, query=query, headers=headers, variables=variables, files=files
method,
query=query,
headers=headers,
variables=variables,
files=files,
extensions=extensions,
)

def _get_headers(
Expand All @@ -117,6 +124,7 @@ def _build_body(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
method: Literal["get", "post"] = "post",
extensions: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, object]]:
if query is None:
assert files is None
Expand All @@ -129,6 +137,9 @@ def _build_body(
if variables:
body["variables"] = variables

if extensions:
body["extensions"] = extensions

if files:
assert variables is not None

Expand All @@ -142,6 +153,9 @@ def _build_body(
if method == "get" and variables:
body["variables"] = json.dumps(variables)

if method == "get" and extensions:
body["extensions"] = json.dumps(extensions)

return body

@staticmethod
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/chalice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

data: Union[Dict[str, object], str, None] = None
Expand Down
16 changes: 13 additions & 3 deletions tests/http/clients/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@


def generate_get_path(
path, query: str, variables: Optional[Dict[str, Any]] = None
path,
query: str,
variables: Optional[Dict[str, Any]] = None,
extensions: Optional[Dict[str, Any]] = None,
) -> str:
body: Dict[str, Any] = {"query": query}
if variables is not None:
body["variables"] = json_module.dumps(variables)
if extensions is not None:
body["extensions"] = json_module.dumps(extensions)

parts = [f"{k}={v}" for k, v in body.items()]
return f"{path}?{'&'.join(parts)}"
Expand Down Expand Up @@ -165,10 +170,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

headers = self._get_headers(method=method, headers=headers, files=files)
Expand All @@ -183,7 +193,7 @@ async def _graphql_request(
endpoint_url = "/graphql"
else:
body = b""
endpoint_url = generate_get_path("/graphql", query, variables)
endpoint_url = generate_get_path("/graphql", query, variables, extensions)

return await self.request(
url=endpoint_url, method=method, body=body, headers=headers
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
headers = self._get_headers(method=method, headers=headers, files=files)
additional_arguments = {**kwargs, **headers}

body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

data: Union[Dict[str, object], str, None] = None
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if body:
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

data: Union[Dict[str, object], str, None] = None
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/litestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
if body := self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
):
if method == "get":
kwargs["params"] = body
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

url = "/graphql"
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/sanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if body:
Expand Down
Loading
Loading