Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let graphql-ws use type inference whenever possible #3704

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This release refactors part of the legacy `graphql-ws` protocol implementation, making it easier to read, maintain, and extend.
45 changes: 18 additions & 27 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@

from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.subscriptions.protocols.graphql_ws.types import (
CompleteMessage,
ConnectionAckMessage,
ConnectionErrorMessage,
ConnectionInitMessage,
ConnectionKeepAliveMessage,
ConnectionTerminateMessage,
DataMessage,
ErrorMessage,
OperationMessage,
StartMessage,
StopMessage,
Expand Down Expand Up @@ -93,15 +88,13 @@ async def handle_message(
async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
payload = message.get("payload")
if payload is not None and not isinstance(payload, dict):
error_message: ConnectionErrorMessage = {"type": "connection_error"}
await self.websocket.send_json(error_message)
await self.send_message({"type": "connection_error"})
await self.websocket.close(code=1000, reason="")
return

self.connection_params = payload

connection_ack_message: ConnectionAckMessage = {"type": "connection_ack"}
await self.websocket.send_json(connection_ack_message)
await self.send_message({"type": "connection_ack"})

if self.keep_alive:
keep_alive_handler = self.handle_keep_alive()
Expand Down Expand Up @@ -139,8 +132,7 @@ async def handle_stop(self, message: StopMessage) -> None:
async def handle_keep_alive(self) -> None:
assert self.keep_alive_interval
while True:
data: ConnectionKeepAliveMessage = {"type": "ka"}
await self.websocket.send_json(data)
await self.send_message({"type": "ka"})
await asyncio.sleep(self.keep_alive_interval)

async def handle_async_results(
Expand All @@ -160,26 +152,22 @@ async def handle_async_results(
)
if isinstance(agen_or_err, PreExecutionError):
assert agen_or_err.errors
error_payload = agen_or_err.errors[0].formatted
error_message: ErrorMessage = {
"type": "error",
"id": operation_id,
"payload": error_payload,
}
await self.websocket.send_json(error_message)
await self.send_message(
{
"type": "error",
"id": operation_id,
"payload": agen_or_err.errors[0].formatted,
}
)
else:
self.subscriptions[operation_id] = agen_or_err

async for result in agen_or_err:
await self.send_data(result, operation_id)
await self.send_data_message(result, operation_id)

await self.websocket.send_json(
CompleteMessage({"type": "complete", "id": operation_id})
)
await self.send_message({"type": "complete", "id": operation_id})
except asyncio.CancelledError:
await self.websocket.send_json(
CompleteMessage({"type": "complete", "id": operation_id})
)
await self.send_message({"type": "complete", "id": operation_id})

async def cleanup_operation(self, operation_id: str) -> None:
if operation_id in self.subscriptions:
Expand All @@ -192,7 +180,7 @@ async def cleanup_operation(self, operation_id: str) -> None:
await self.tasks[operation_id]
del self.tasks[operation_id]

async def send_data(
async def send_data_message(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
data_message: DataMessage = {
Expand All @@ -209,7 +197,10 @@ async def send_data(
if execution_result.extensions:
data_message["payload"]["extensions"] = execution_result.extensions

await self.websocket.send_json(data_message)
await self.send_message(data_message)

async def send_message(self, message: OperationMessage) -> None:
await self.websocket.send_json(message)


__all__ = ["BaseGraphQLWSHandler"]
1 change: 1 addition & 0 deletions strawberry/subscriptions/protocols/graphql_ws/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ConnectionKeepAliveMessage(TypedDict):
DataMessage,
ErrorMessage,
CompleteMessage,
ConnectionKeepAliveMessage,
]


Expand Down
4 changes: 4 additions & 0 deletions tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Message as GraphQLTransportWSMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage
from strawberry.types import ExecutionResult

logger = logging.getLogger("strawberry.test.http_client")
Expand Down Expand Up @@ -307,6 +308,9 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]:
async def send_message(self, message: GraphQLTransportWSMessage) -> None:
await self.send_json(message)

async def send_legacy_message(self, message: OperationMessage) -> None:
await self.send_json(message)


class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
def on_init(self) -> None:
Expand Down
12 changes: 6 additions & 6 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def assert_next(
async def test_unknown_message_type(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_message({"type": "NOT_A_MESSAGE_TYPE"}) # type: ignore
await ws.send_json({"type": "NOT_A_MESSAGE_TYPE"})

await ws.receive(timeout=2)
assert ws.closed
Expand All @@ -83,7 +83,7 @@ async def test_unknown_message_type(ws_raw: WebSocketClient):
async def test_missing_message_type(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_message({"notType": None}) # type: ignore
await ws.send_json({"notType": None})

await ws.receive(timeout=2)
assert ws.closed
Expand All @@ -92,7 +92,7 @@ async def test_missing_message_type(ws_raw: WebSocketClient):


async def test_parsing_an_invalid_message(ws: WebSocketClient):
await ws.send_message({"type": "subscribe", "notPayload": None}) # type: ignore
await ws.send_json({"type": "subscribe", "notPayload": None})

await ws.receive(timeout=2)
assert ws.closed
Expand Down Expand Up @@ -218,7 +218,7 @@ async def test_close_twice(

# We set payload is set to "invalid value" to force a invalid payload error
# which will close the connection
await ws.send_message({"type": "connection_init", "payload": "invalid value"}) # type: ignore
await ws.send_json({"type": "connection_init", "payload": "invalid value"})

# Yield control so that ._close can be called
await asyncio.sleep(0)
Expand Down Expand Up @@ -830,7 +830,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient):

async def test_rejects_connection_params_not_dict(ws_raw: WebSocketClient):
ws = ws_raw
await ws.send_message({"type": "connection_init", "payload": "gonna fail"}) # type: ignore
await ws.send_json({"type": "connection_init", "payload": "gonna fail"})

await ws.receive(timeout=2)
assert ws.closed
Expand All @@ -846,7 +846,7 @@ async def test_rejects_connection_params_with_wrong_type(
payload: object, ws_raw: WebSocketClient
):
ws = ws_raw
await ws.send_message({"type": "connection_init", "payload": payload}) # type: ignore
await ws.send_json({"type": "connection_init", "payload": payload})

await ws.receive(timeout=2)
assert ws.closed
Expand Down
Loading
Loading