Skip to content

Commit

Permalink
Support custom JSON decoders for views (#3709)
Browse files Browse the repository at this point in the history
  • Loading branch information
DoctorJohn authored Nov 22, 2024
1 parent a3dd2df commit e8f4b6e
Show file tree
Hide file tree
Showing 18 changed files with 262 additions and 19 deletions.
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Release type: minor

The view classes of all integrations now have a `decode_json` method that allows
you to customize the decoding of HTTP JSON requests.

This is useful if you want to use a different JSON decoder, for example, to
optimize performance.
21 changes: 21 additions & 0 deletions docs/integrations/aiohttp.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,27 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP and WebSocket JSON
requests. By default we use `json.loads` but you can override this method to use
a different decoder.

```python
from strawberry.aiohttp.views import GraphQLView
from typing import Union
import orjson


class MyGraphQLView(GraphQLView):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
23 changes: 22 additions & 1 deletion docs/integrations/asgi.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ We allow to extend the base `GraphQL` app, by overriding the following methods:
- `async get_context(self, request: Union[Request, WebSocket], response: Optional[Response] = None) -> Any`
- `async get_root_value(self, request: Request) -> Any`
- `async process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse`
- `def encode_json(self, response_data: object) -> str`
- `def decode_json(self, data: Union[str, bytes]) -> object`
- `def encode_json(self, data: object) -> str`
- `async def render_graphql_ide(self, request: Request) -> Response`

### get_context
Expand Down Expand Up @@ -167,6 +168,26 @@ class MyGraphQL(GraphQL):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP JSON requests. By default
we use `json.loads` but you can override this method to use a different decoder.

```python
from strawberry.asgi import GraphQL
from typing import Union
import orjson


class MyGraphQLView(GraphQL):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
20 changes: 20 additions & 0 deletions docs/integrations/chalice.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,26 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP JSON requests. By default
we use `json.loads` but you can override this method to use a different decoder.

```python
from strawberry.chalice.views import GraphQLView
from typing import Union
import orjson


class MyGraphQLView(GraphQLView):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
21 changes: 21 additions & 0 deletions docs/integrations/django.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,27 @@ class MyGraphQLView(AsyncGraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP and WebSocket JSON
requests. By default we use `json.loads` but you can override this method to use
a different decoder.

```python
from strawberry.django.views import AsyncGraphQLView
from typing import Union
import orjson


class MyGraphQLView(AsyncGraphQLView):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
21 changes: 21 additions & 0 deletions docs/integrations/fastapi.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,27 @@ class MyGraphQLRouter(GraphQLRouter):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP and WebSocket JSON
requests. By default we use `json.loads` but you can override this method to use
a different decoder.

```python
from strawberry.fastapi import GraphQLRouter
from typing import Union
import orjson


class MyGraphQLRouter(GraphQLRouter):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
20 changes: 20 additions & 0 deletions docs/integrations/flask.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,26 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP JSON requests. By default
we use `json.loads` but you can override this method to use a different decoder.

```python
from strawberry.flask.views import GraphQLView
from typing import Union
import orjson


class MyGraphQLView(GraphQLView):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
20 changes: 20 additions & 0 deletions docs/integrations/quart.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,26 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP JSON requests. By default
we use `json.loads` but you can override this method to use a different decoder.

```python
from strawberry.quart.views import GraphQLView
from typing import Union
import orjson


class MyGraphQLView(GraphQLView):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
20 changes: 20 additions & 0 deletions docs/integrations/sanic.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,26 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.

### decode_json

`decode_json` allows to customize the decoding of HTTP JSON requests. By default
we use `json.loads` but you can override this method to use a different decoder.

```python
from strawberry.sanic.views import GraphQLView
from typing import Union
import orjson


class MyGraphQLView(GraphQLView):
def decode_json(self, data: Union[str, bytes]) -> object:
return orjson.loads(data)
```

Make sure your code raises `json.JSONDecodeError` or a subclass of it if the
JSON cannot be decoded. The library shown in the example above, `orjson`, does
this by default.

### encode_json

`encode_json` allows to customize the encoding of HTTP and WebSocket JSON
Expand Down
4 changes: 2 additions & 2 deletions strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def __init__(

async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
) -> AsyncGenerator[object, None]:
async for ws_message in self.ws:
if ws_message.type == http.WSMsgType.TEXT:
try:
yield ws_message.json()
yield self.view.decode_json(ws_message.data)
except JSONDecodeError:
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
Expand Down
5 changes: 3 additions & 2 deletions strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ def __init__(

async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
) -> AsyncGenerator[object, None]:
try:
while self.ws.application_state != WebSocketState.DISCONNECTED:
try:
yield await self.ws.receive_json()
text = await self.ws.receive_text()
yield self.view.decode_json(text)
except JSONDecodeError: # noqa: PERF203
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
Expand Down
5 changes: 2 additions & 3 deletions strawberry/channels/handlers/ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Dict,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -39,7 +38,7 @@ def __init__(

async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
) -> AsyncGenerator[object, None]:
while True:
message = await self.ws_consumer.message_queue.get()

Expand All @@ -50,7 +49,7 @@ async def iter_json(
raise NonTextMessageReceived()

try:
yield json.loads(message["message"])
yield self.view.decode_json(message["message"])
except json.JSONDecodeError:
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
Expand Down
2 changes: 1 addition & 1 deletion strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, view: "AsyncBaseHTTPView") -> None:
@abc.abstractmethod
def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]: ...
) -> AsyncGenerator[object, None]: ...

@abc.abstractmethod
async def send_json(self, message: Mapping[str, object]) -> None: ...
Expand Down
5 changes: 4 additions & 1 deletion strawberry/http/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ def is_request_allowed(self, request: BaseRequestProtocol) -> bool:

def parse_json(self, data: Union[str, bytes]) -> Any:
try:
return json.loads(data)
return self.decode_json(data)
except json.JSONDecodeError as e:
raise HTTPException(400, "Unable to parse request body as JSON") from e

def decode_json(self, data: Union[str, bytes]) -> object:
return json.loads(data)

def encode_json(self, data: object) -> str:
return json.dumps(data)

Expand Down
4 changes: 2 additions & 2 deletions strawberry/litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __init__(

async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
) -> AsyncGenerator[object, None]:
try:
while self.ws.connection_state != "disconnect":
text = await self.ws.receive_text()
Expand All @@ -212,7 +212,7 @@ async def iter_json(
raise NonTextMessageReceived()

try:
yield json.loads(text)
yield self.view.decode_json(text)
except json.JSONDecodeError:
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
Expand Down
17 changes: 17 additions & 0 deletions tests/http/test_http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

from strawberry.http.base import BaseView

from .clients.base import HttpClient


Expand All @@ -11,3 +13,18 @@ async def test_does_only_allow_get_and_post(
response = await http_client.request(url="/graphql", method=method) # type: ignore

assert response.status_code == 405


async def test_the_http_handler_uses_the_views_decode_json_method(
http_client: HttpClient, mocker
):
spy = mocker.spy(BaseView, "decode_json")

response = await http_client.query(query="{ hello }")
assert response.status_code == 200

data = response.json["data"]
assert isinstance(data, dict)
assert data["hello"] == "Hello world"

assert spy.call_count == 1
Loading

0 comments on commit e8f4b6e

Please sign in to comment.