[WIP] Proposal: replace asyncio with anyio #165

Draft · wants to merge 2 commits into main

433 changes: 277 additions & 156 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions pyproject.toml
@@ -42,15 +42,13 @@ packages = [
]

[tool.poetry.dependencies]
python = "^3.6"
python = "^3.6,>=3.6.2"
typing-extensions = { version = "^4.0", python = "<3.8" }
anyio = "^3.5.0"
trio = { version = "^0.20.0", python = ">=3.7", optional=true}

[tool.poetry.dev-dependencies]
pytest = "^6.2"
pytest-asyncio = [
{version=">=0.17,<1", python = ">=3.7" },
{version=">=0.16,<0.17", python = "<3.7" },
]
pytest-benchmark = "^3.4"
pytest-cov = "^3.0"
pytest-describe = "^2.0"
@@ -66,6 +64,11 @@ sphinx_rtd_theme = ">=1,<2"
check-manifest = ">=0.47,<1"
bump2version = ">=1.0,<2"
tox = "^3.24"
trio-typing = "^0.7.0"

[tool.poetry.extras]
trio = ["trio"]


[tool.black]
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']
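With pytest-asyncio dropped and anyio added (plus trio as an optional extra), the package is meant to run on either backend through anyio. A minimal sketch of what that enables once the port lands; the schema and helper below are illustrative, not part of this PR:

```python
import anyio

from graphql import (
    GraphQLField,
    GraphQLObjectType,
    GraphQLSchema,
    GraphQLString,
    graphql,
)

# Hypothetical one-field schema, used only to illustrate backend selection.
schema = GraphQLSchema(
    query=GraphQLObjectType(
        "Query",
        {"hello": GraphQLField(GraphQLString, resolve=lambda *_: "world")},
    )
)


async def main() -> None:
    # Assumes this PR's anyio port is in place; released graphql-core
    # only supports the asyncio backend.
    result = await graphql(schema, "{ hello }")
    print(result.data)  # {'hello': 'world'}


# The `trio` extra provides the trio backend; backend="asyncio" also works.
anyio.run(main, backend="trio")
```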
66 changes: 50 additions & 16 deletions src/graphql/execution/execute.py
@@ -1,4 +1,6 @@
from asyncio import ensure_future, gather
from anyio import create_task_group, ExceptionGroup
from asyncio import ensure_future, CancelledError
from sniffio import current_async_library
from collections.abc import Mapping
from inspect import isawaitable
from typing import (
@@ -444,13 +446,16 @@ def execute_fields(
# field, which is possibly a coroutine object. Return a coroutine object that
# will yield this same map, but with any coroutines awaited in parallel and
# replaced with the values they yielded.
async def await_field(field: str) -> None:
results[field] = await results[field]

async def get_results() -> Dict[str, Any]:
results.update(
zip(
awaitable_fields,
await gather(*(results[field] for field in awaitable_fields)),
)
)
try:
async with create_task_group() as tg:
for field in awaitable_fields:
tg.start_soon(await_field, field)
except ExceptionGroup as exc:
raise exc.exceptions[0]
return results

return get_results()
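The `asyncio.gather` call in `execute_fields` is replaced by an anyio task group that awaits each field result in place; with anyio 3.x, multiple task failures arrive wrapped in `anyio.ExceptionGroup`, and the first inner exception is re-raised to keep the old gather-like behaviour (a single failure propagates unwrapped). A standalone sketch of the pattern, with illustrative names:

```python
from typing import Any, Awaitable, Dict

from anyio import ExceptionGroup, create_task_group  # anyio 3.x export


async def gather_dict(awaitables: Dict[str, Awaitable[Any]]) -> Dict[str, Any]:
    """Await all values of a dict concurrently, in place of asyncio.gather."""
    results: Dict[str, Any] = dict(awaitables)

    async def await_one(key: str) -> None:
        results[key] = await results[key]

    try:
        async with create_task_group() as tg:
            for key in awaitables:
                tg.start_soon(await_one, key)
    except ExceptionGroup as exc:
        # Mimic gather(): surface only the first error.
        raise exc.exceptions[0]
    return results
```

The same pattern is applied again further down, in the list completion code and in `default_type_resolver`.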
@@ -531,6 +536,10 @@ async def await_result() -> Any:
return await completed
return completed
except Exception as raw_error:
if isinstance(
raw_error, CancelledError
): # pragma: no cover (Python >= 3.8)
raise
error = located_error(raw_error, field_nodes, path.as_list())
self.handle_field_error(error, return_type)
return None
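The `isinstance(raw_error, CancelledError)` guards added here and in the similar handlers below exist because on Python 3.7 and earlier `asyncio.CancelledError` subclasses `Exception`, so a plain `except Exception` would otherwise turn a cancellation into a field error; on 3.8+ it derives from `BaseException` and never reaches these handlers, hence the coverage pragma. A tiny illustrative sketch of the rule:

```python
from asyncio import CancelledError


def as_field_error_or_reraise(raw_error: Exception) -> Exception:
    """Illustrative helper: never convert a cancellation into a field error."""
    if isinstance(raw_error, CancelledError):  # only reachable on Python < 3.8
        raise raw_error
    return raw_error  # safe to wrap with located_error() and report
```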
@@ -546,6 +555,10 @@ async def await_completed() -> Any:
try:
return await completed
except Exception as raw_error:
if isinstance(
raw_error, CancelledError
): # pragma: no cover (Python >= 3.8)
raise
error = located_error(raw_error, field_nodes, path.as_list())
self.handle_field_error(error, return_type)
return None
@@ -714,6 +727,10 @@ async def await_completed(item: Any, item_path: Path) -> Any:
return await completed
return completed
except Exception as raw_error:
if isinstance(
raw_error, CancelledError
): # pragma: no cover (Python >= 3.8)
raise
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
@@ -732,6 +749,10 @@ async def await_completed(item: Any, item_path: Path) -> Any:
try:
return await item
except Exception as raw_error:
if isinstance(
raw_error, CancelledError
): # pragma: no cover (Python >= 3.8)
raise
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
@@ -752,14 +773,16 @@ async def await_completed(item: Any, item_path: Path) -> Any:
return completed_results

# noinspection PyShadowingNames
async def await_index(index: int) -> None:
completed_results[index] = await completed_results[index]

async def get_completed_results() -> List[Any]:
for index, result in zip(
awaitable_indices,
await gather(
*(completed_results[index] for index in awaitable_indices)
),
):
completed_results[index] = result
try:
async with create_task_group() as tg:
for index in awaitable_indices:
tg.start_soon(await_index, index)
except ExceptionGroup as exc:
raise exc.exceptions[0]
return completed_results

return get_completed_results()
@@ -1096,7 +1119,8 @@ def execute_sync(

# Assert that the execution was synchronous.
if isawaitable(result):
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
if current_async_library() == "asyncio":
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
raise RuntimeError("GraphQL execution failed to complete synchronously.")

return cast(ExecutionResult, result)
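`execute_sync` used to create an asyncio future solely to cancel the stray coroutine; since that coroutine may now belong to either backend, the port asks sniffio which library is running and only falls back to `ensure_future` under asyncio. For reference, a small sketch of what `current_async_library` reports (the helper name is illustrative):

```python
import sniffio


async def running_backend() -> str:
    # Illustrative helper: returns "asyncio" under asyncio and "trio" under
    # trio; raises sniffio.AsyncLibraryNotFoundError when no async library
    # is running.
    return sniffio.current_async_library()
```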
@@ -1216,8 +1240,18 @@ def default_type_resolver(

if awaitable_is_type_of_results:
# noinspection PyShadowingNames
is_type_of_results = [None for _ in awaitable_is_type_of_results]

async def await_is_type_of_result(index: int) -> None:
is_type_of_results[index] = await awaitable_is_type_of_results[index]

async def get_type() -> Optional[str]:
is_type_of_results = await gather(*awaitable_is_type_of_results)
try:
async with create_task_group() as tg:
for index in range(len(awaitable_is_type_of_results)):
tg.start_soon(await_is_type_of_result, index)
except ExceptionGroup as exc:
raise exc.exceptions[0]
for is_type_of_result, type_ in zip(is_type_of_results, awaitable_types):
if is_type_of_result:
return type_.name
63 changes: 35 additions & 28 deletions src/graphql/execution/map_async_iterator.py
@@ -1,7 +1,12 @@
from asyncio import CancelledError, Event, Task, ensure_future, wait
from concurrent.futures import FIRST_COMPLETED
from anyio import (
get_cancelled_exc_class,
create_task_group,
Event,
CancelScope,
)
from anyio.abc import TaskGroup
from inspect import isasyncgen, isawaitable
from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union
from typing import cast, Any, AsyncIterable, Callable, Optional, Type, Union
from types import TracebackType

__all__ = ["MapAsyncIterator"]
@@ -27,38 +32,40 @@ def __aiter__(self) -> "MapAsyncIterator":
"""Get the iterator object."""
return self

async def _wait_for_close(self, tg: TaskGroup) -> None:
await self._close_event.wait()
tg.cancel_scope.cancel()

async def __anext__(self) -> Any:
"""Get the next value of the iterator."""
if self.is_closed:
if not isasyncgen(self.iterator):
raise StopAsyncIteration
value = await self.iterator.__anext__()
else:
aclose = ensure_future(self._close_event.wait())
anext = ensure_future(self.iterator.__anext__())

close_evt = None
iterator_exc = None
try:
pending: Set[Task] = (
await wait([aclose, anext], return_when=FIRST_COMPLETED)
)[1]
except CancelledError:
# cancel underlying tasks and close
aclose.cancel()
anext.cancel()
await self.aclose()
raise # re-raise the cancellation

for task in pending:
task.cancel()

if aclose.done():
async with create_task_group() as tg:
# store a reference to the current close event, since it may be replaced while we wait
close_evt = self._close_event
tg.start_soon(self._wait_for_close, tg)
try:
value = await self.iterator.__anext__()
except BaseException as exc:
iterator_exc = exc
tg.cancel_scope.cancel()
except BaseException:
# We ignore this and use the iterator exception (if any)
pass
if close_evt is not None and close_evt.is_set():
# closed from outside via `is_closed=True / aclose`
raise StopAsyncIteration

error = anext.exception()
if error:
raise error

value = anext.result()
if iterator_exc is not None:
if isinstance(iterator_exc, get_cancelled_exc_class()):
with CancelScope(shield=True):
await self.aclose()
raise iterator_exc

result = self.callback(value)

@@ -111,5 +118,5 @@ def is_closed(self, value: bool) -> None:
"""Mark the iterator as closed."""
if value:
self._close_event.set()
else:
self._close_event.clear()
elif self._close_event.is_set():
self._close_event = Event()
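The rewritten `__anext__` races the iterator against the close event with an anyio task group: a watcher task cancels the group's scope once the event fires, which also cancels the pending `__anext__`, and a shielded `CancelScope` lets `aclose` finish even when the iterator itself was cancelled. The setter change reflects that anyio's `Event` has no `clear()`, so a fresh event is created instead. A standalone sketch of the race pattern, assuming anyio 3.x and illustrative names:

```python
from typing import Any, AsyncIterator

import anyio


async def next_or_stop(iterator: AsyncIterator[Any], close_event: anyio.Event) -> Any:
    """Yield the next item, or raise StopAsyncIteration once close_event fires."""
    value: Any = None
    received = False

    async with anyio.create_task_group() as tg:

        async def cancel_on_close() -> None:
            # Watcher: cancelling the scope also cancels the pending __anext__.
            await close_event.wait()
            tg.cancel_scope.cancel()

        tg.start_soon(cancel_on_close)
        value = await iterator.__anext__()
        received = True
        tg.cancel_scope.cancel()  # stop the watcher; we have a value

    if not received:  # the close event won the race
        raise StopAsyncIteration
    return value
```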
3 changes: 3 additions & 0 deletions src/graphql/execution/subscribe.py
@@ -1,3 +1,4 @@
from asyncio import CancelledError
from inspect import isawaitable
from typing import (
Any,
@@ -209,4 +210,6 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:

return event_stream
except Exception as error:
if isinstance(error, CancelledError): # pragma: no cover (Python >= 3.8)
raise
raise located_error(error, field_nodes, path.as_list())
4 changes: 3 additions & 1 deletion src/graphql/graphql.py
@@ -1,4 +1,5 @@
from asyncio import ensure_future
from sniffio import current_async_library
from inspect import isawaitable
from typing import Any, Awaitable, Callable, Dict, Optional, Union, Type, cast

@@ -143,7 +144,8 @@ def graphql_sync(

# Assert that the execution was synchronous.
if isawaitable(result):
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
if current_async_library() == "asyncio":
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
raise RuntimeError("GraphQL execution failed to complete synchronously.")

return cast(ExecutionResult, result)
3 changes: 3 additions & 0 deletions src/graphql/pyutils/__init__.py
@@ -32,6 +32,7 @@
from .path import Path
from .print_path_list import print_path_list
from .simple_pub_sub import SimplePubSub, SimplePubSubIterator
from .broadcast_stream import MemoryObjectBroadcastStream, create_broadcast_stream
from .undefined import Undefined, UndefinedType

__all__ = [
@@ -60,6 +61,8 @@
"print_path_list",
"SimplePubSub",
"SimplePubSubIterator",
"MemoryObjectBroadcastStream",
"create_broadcast_stream",
"Undefined",
"UndefinedType",
]
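The new `MemoryObjectBroadcastStream` / `create_broadcast_stream` exports come from a module not shown in this diff, so its API cannot be inferred here; the sketch below only demonstrates the plain anyio memory object stream primitive it is presumably layered on:

```python
import anyio


async def demo() -> None:
    # A plain anyio memory object stream: one sender, one receiver.
    # The broadcast helper added by this PR is not shown in the diff.
    send, receive = anyio.create_memory_object_stream(max_buffer_size=1)
    async with send, receive:
        await send.send("payload")
        print(await receive.receive())  # "payload"


anyio.run(demo)
```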