Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use much stricter typings #39

Merged
merged 5 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 119 additions & 70 deletions aiogqlc/client.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,42 @@
import asyncio
import json
from io import IOBase
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from types import TracebackType
from typing import (
AsyncContextManager,
AsyncGenerator,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)

import aiohttp
import aiohttp.client

from aiogqlc.constants import (
GQL_COMPLETE,
GQL_CONNECTION_ACK,
GQL_CONNECTION_ERROR,
GQL_CONNECTION_INIT,
GQL_CONNECTION_KEEP_ALIVE,
GQL_CONNECTION_TERMINATE,
GQL_DATA,
GQL_ERROR,
GQL_START,
GQL_STOP,
GRAPHQL_WS,
)
from aiogqlc.constants import GRAPHQL_WS
from aiogqlc.errors import (
GraphQLWSConnectionError,
GraphQLWSOperationError,
GraphQLWSProtocolError,
)
from aiogqlc.types import (
ConnectionInitParams,
FilesToPathsMapping,
GraphQLWSConnectionInitMessage,
GraphQLWSConnectionTerminateMessage,
GraphQLWSDataMessagePayload,
GraphQLWSServerConnectionOperationMessage,
GraphQLWSServerExecutionOperationMessage,
GraphQLWSServerOperationMessage,
GraphQLWSStartMessage,
GraphQLWSStopMessage,
Payload,
Variables,
VariableValue,
)
from aiogqlc.utils import serialize_payload


Expand All @@ -32,16 +45,18 @@ def __init__(
self,
endpoint: str,
session: aiohttp.ClientSession,
connection_params: Optional[Dict[str, Any]] = None,
connection_params: Optional[ConnectionInitParams] = None,
) -> None:
self._endpoint = endpoint
self._session = session
self._connection_params = connection_params
self._last_operation_id = 0
self._ws_context: aiohttp.client._WSRequestContextManager
self._ws_context: AsyncContextManager[aiohttp.ClientWebSocketResponse]
self._ws: aiohttp.ClientWebSocketResponse
self._operations_message_queues: Dict[str, asyncio.Queue] = {}
self._connection_handler_task: asyncio.Task
self._operations_message_queues: Dict[
str, asyncio.Queue[GraphQLWSServerOperationMessage]
] = {}
self._connection_handler_task: asyncio.Task[None]

async def __aenter__(self) -> "GraphQLWSManager":
self._ws_context = self._session.ws_connect(
Expand All @@ -52,17 +67,22 @@ async def __aenter__(self) -> "GraphQLWSManager":
self._connection_handler_task = asyncio.create_task(self.handle_connection())
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.terminate_connection()
await self._connection_handler_task
await self._ws_context.__aexit__(exc_type, exc_val, exc_tb)

async def subscribe(
self,
query: str,
variables: Optional[Dict[str, Any]] = None,
variables: Optional[Variables] = None,
operation: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
) -> AsyncGenerator[GraphQLWSDataMessagePayload, None]:
operation_id = self.get_next_operation_id()
self._operations_message_queues[operation_id] = asyncio.Queue()

Expand All @@ -82,75 +102,94 @@ def get_next_operation_id(self) -> str:
self._last_operation_id += 1
return str(self._last_operation_id)

async def init_connection(self, params: Optional[Dict[str, Any]]) -> None:
await self._ws.send_json({"type": GQL_CONNECTION_INIT, "payload": params})
message = await self._ws.receive_json()
async def init_connection(self, params: Optional[ConnectionInitParams]) -> None:
connection_init_message: GraphQLWSConnectionInitMessage = {
"type": "connection_init",
"payload": params or {},
}
await self._ws.send_json(connection_init_message)

if message["type"] == GQL_CONNECTION_ACK:
message: GraphQLWSServerConnectionOperationMessage = (
await self._ws.receive_json()
)

if message["type"] == "connection_ack":
return
elif message["type"] == GQL_CONNECTION_ERROR:
raise GraphQLWSConnectionError(message.get("payload"))
else:
raise GraphQLWSProtocolError(message.get("payload"))

if message["type"] == "connection_error":
raise GraphQLWSConnectionError(message["payload"])

raise GraphQLWSProtocolError(message.get("payload"))

async def start_operation(
self,
operation_id: str,
query: str,
variables: Optional[Dict[str, Any]] = None,
variables: Optional[Variables] = None,
operation: Optional[str] = None,
) -> None:
payload = serialize_payload(query, variables, operation)
await self._ws.send_json(
{
"type": GQL_START,
"id": operation_id,
"payload": payload,
}
)
start_message: GraphQLWSStartMessage = {
"type": "start",
"id": operation_id,
"payload": payload,
}
await self._ws.send_json(start_message)

async def handle_connection(self) -> None:
async for message in self._ws: # type: aiohttp.WSMessage
if message.type != aiohttp.WSMsgType.TEXT:
async for ws_message in self._ws:
if ws_message.type != aiohttp.WSMsgType.TEXT:
continue

operation_message = message.json()
message: GraphQLWSServerOperationMessage = ws_message.json()

if operation_message.get("type") == GQL_CONNECTION_KEEP_ALIVE:
if message["type"] == "ka":
continue

if "id" in operation_message:
self.yield_operation_message(operation_message)
if (
message["type"] == "data"
or message["type"] == "error"
or message["type"] == "complete"
):
self.yield_operation_message(message)
continue

async def handle_operation(
self, operation_id: str
) -> AsyncGenerator[Dict[str, Any], None]:
) -> AsyncGenerator[GraphQLWSDataMessagePayload, None]:
while True:
operation_message = await self._operations_message_queues[
operation_id
].get()
payload = operation_message.get("payload")

if operation_message["type"] == GQL_DATA:
yield payload
if operation_message["type"] == "data":
yield operation_message["payload"]
continue

if operation_message["type"] == GQL_ERROR:
raise GraphQLWSOperationError(payload)
if operation_message["type"] == "error":
raise GraphQLWSOperationError(operation_message["payload"])

if operation_message["type"] == GQL_COMPLETE:
if operation_message["type"] == "complete":
return

def yield_operation_message(self, operation_message: dict) -> None:
def yield_operation_message(
self, operation_message: GraphQLWSServerExecutionOperationMessage
) -> None:
operation_id = operation_message["id"]
self._operations_message_queues[operation_id].put_nowait(operation_message)

async def stop_operation(self, operation_id: str) -> None:
await self._ws.send_json({"type": GQL_STOP, "id": operation_id})
stop_message: GraphQLWSStopMessage = {
"type": "stop",
"id": operation_id,
}
await self._ws.send_json(stop_message)

async def terminate_connection(self) -> None:
await self._ws.send_json({"type": GQL_CONNECTION_TERMINATE})
terminate_message: GraphQLWSConnectionTerminateMessage = {
"type": "connection_terminate"
}
await self._ws.send_json(terminate_message)


class GraphQLClient:
Expand All @@ -159,7 +198,7 @@ def __init__(self, endpoint: str, session: aiohttp.ClientSession) -> None:
self.session = session

def connect(
self, protocol: str = GRAPHQL_WS, params: Optional[Dict[str, Any]] = None
self, protocol: str = GRAPHQL_WS, params: Optional[ConnectionInitParams] = None
) -> GraphQLWSManager:
if protocol == GRAPHQL_WS:
return GraphQLWSManager(self.endpoint, self.session, params)
Expand All @@ -168,16 +207,19 @@ def connect(
async def execute(
self,
query: str,
variables: Optional[Dict[str, Any]] = None,
variables: Optional[Variables] = None,
operation: Optional[str] = None,
**kwargs,
**kwargs: Dict[str, object],
) -> aiohttp.ClientResponse:
nulled_variables, files_to_paths_mapping = self.prepare(variables)
data_param: Dict[str, Any]
data_param: Dict[str, Union[aiohttp.FormData, Payload]]

if files_to_paths_mapping:
form_data = self.prepare_multipart(
query, nulled_variables, files_to_paths_mapping, operation
query=query,
nulled_variables=nulled_variables,
files_to_paths_mapping=files_to_paths_mapping,
operation=operation,
)
data_param = {"data": form_data}
else:
Expand All @@ -190,20 +232,20 @@ async def execute(

@classmethod
def prepare(
cls, variables: Optional[Dict[str, Any]]
) -> Tuple[dict, Dict[IOBase, List[str]]]:
files_to_paths_mapping: Dict[IOBase, List[str]] = {}
cls, variables: Optional[Variables]
) -> Tuple[Optional[Variables], FilesToPathsMapping]:
files_to_paths_mapping: FilesToPathsMapping = {}

def separate_files(path: str, obj: object) -> Any:
def separate_files(path: str, obj: VariableValue) -> VariableValue:
if isinstance(obj, list):
nulled_list = []
for key, value in enumerate(obj):
value = separate_files(f"{path}.{key}", value)
nulled_list: List[VariableValue] = []
for index, value in enumerate(obj):
value = separate_files(f"{path}.{index}", value)
nulled_list.append(value)
return nulled_list

elif isinstance(obj, dict):
nulled_dict = {}
nulled_dict: Dict[str, VariableValue] = {}
for key, value in obj.items():
value = separate_files(f"{path}.{key}", value)
nulled_dict[key] = value
Expand All @@ -219,19 +261,26 @@ def separate_files(path: str, obj: object) -> Any:
else:
return obj

nulled_variables = separate_files("variables", variables)
if variables is None:
return None, files_to_paths_mapping

nulled_variables = dict(
(key, separate_files(f"variables.{key}", value))
for key, value in variables.items()
)

return nulled_variables, files_to_paths_mapping

@classmethod
def prepare_multipart(
cls,
query: str,
variables: Dict[str, Any],
files_to_paths_mapping: Dict[IOBase, List[str]],
nulled_variables: Optional[Variables],
files_to_paths_mapping: FilesToPathsMapping,
operation: Optional[str] = None,
) -> aiohttp.FormData:
form_data = aiohttp.FormData()
operations = serialize_payload(query, variables, operation)
operations = serialize_payload(query, nulled_variables, operation)

file_map = {
str(i): files_to_paths_mapping[file]
Expand Down
10 changes: 0 additions & 10 deletions aiogqlc/constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1 @@
GRAPHQL_WS = "graphql-ws"
GQL_CONNECTION_INIT = "connection_init"
GQL_CONNECTION_ACK = "connection_ack"
GQL_CONNECTION_ERROR = "connection_error"
GQL_CONNECTION_TERMINATE = "connection_terminate"
GQL_CONNECTION_KEEP_ALIVE = "ka"
GQL_START = "start"
GQL_DATA = "data"
GQL_ERROR = "error"
GQL_COMPLETE = "complete"
GQL_STOP = "stop"
2 changes: 1 addition & 1 deletion aiogqlc/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ class GraphQLWSError(Exception):


class GraphQLWSConnectionError(GraphQLWSError):
def __init__(self, payload: dict):
def __init__(self, payload: object):
self.payload = payload
super().__init__(payload)

Expand Down
Loading