Skip to content

Commit

Permalink
Low-level API: Lazy asyncio import (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Jun 29, 2024
1 parent 8bf26c1 commit 6a6d005
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 70 deletions.
112 changes: 69 additions & 43 deletions src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

__all__ = ["AsyncIOBackend"]

import asyncio
import asyncio.base_events
import contextvars
import functools
import math
Expand All @@ -32,21 +30,9 @@

from .... import _utils
from ....constants import HAPPY_EYEBALLS_DELAY as _DEFAULT_HAPPY_EYEBALLS_DELAY
from ...transports.abc import AsyncDatagramListener, AsyncDatagramTransport, AsyncListener, AsyncStreamTransport
from .. import _sniffio_helpers
from ..abc import AsyncBackend as AbstractAsyncBackend, ILock, TaskInfo
from ._asyncio_utils import (
create_connection,
create_datagram_connection,
open_listener_sockets_from_getaddrinfo_result,
resolve_local_addresses,
)
from .datagram.endpoint import create_datagram_endpoint
from .datagram.listener import DatagramListenerSocketAdapter
from .datagram.socket import AsyncioTransportDatagramSocketAdapter
from .stream.listener import AcceptedSocketFactory, ListenerSocketAdapter
from .stream.socket import AsyncioTransportStreamSocketAdapter, StreamReaderBufferedProtocol
from .tasks import CancelScope, TaskGroup, TaskUtils
from .threads import ThreadsPortal
from ..abc import AsyncBackend as AbstractAsyncBackend, CancelScope, ICondition, IEvent, ILock, TaskGroup, TaskInfo, ThreadsPortal

_P = ParamSpec("_P")
_T = TypeVar("_T")
Expand All @@ -55,48 +41,70 @@


class AsyncIOBackend(AbstractAsyncBackend):
__slots__ = ()
__slots__ = (
"__asyncio",
"__coro_yield",
"__cancel_shielded_coro_yield",
"__cancel_shielded_await",
)

def __init__(self) -> None:
import asyncio

from .tasks import TaskUtils

self.__asyncio = asyncio

self.__coro_yield = TaskUtils.coro_yield
self.__cancel_shielded_coro_yield = TaskUtils.cancel_shielded_coro_yield
self.__cancel_shielded_await = TaskUtils.cancel_shielded_await

def bootstrap(
self,
coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]],
*args: *_T_PosArgs,
runner_options: Mapping[str, Any] | None = None,
) -> _T:
with asyncio.Runner(**(runner_options or {})) as runner:
with self.__asyncio.Runner(**(runner_options or {})) as runner:
return runner.run(coro_func(*args))

async def coro_yield(self) -> None:
await TaskUtils.coro_yield()
await self.__coro_yield()

async def cancel_shielded_coro_yield(self) -> None:
await TaskUtils.cancel_shielded_coro_yield()
await self.__cancel_shielded_coro_yield()

def get_cancelled_exc_class(self) -> type[BaseException]:
return asyncio.CancelledError
return self.__asyncio.CancelledError

async def ignore_cancellation(self, coroutine: Awaitable[_T_co]) -> _T_co:
return await TaskUtils.cancel_shielded_await(coroutine)
return await self.__cancel_shielded_await(coroutine)

def open_cancel_scope(self, *, deadline: float = math.inf) -> CancelScope:
from .tasks import CancelScope

return CancelScope(deadline=deadline)

def current_time(self) -> float:
loop = asyncio.get_running_loop()
loop = self.__asyncio.get_running_loop()
return loop.time()

async def sleep(self, delay: float) -> None:
await asyncio.sleep(delay)
await self.__asyncio.sleep(delay)

async def sleep_forever(self) -> NoReturn:
loop = asyncio.get_running_loop()
loop = self.__asyncio.get_running_loop()
await loop.create_future()
raise AssertionError("await an unused future cannot end in any other way than by cancellation")

def create_task_group(self) -> TaskGroup:
from .tasks import TaskGroup

return TaskGroup()

def get_current_task(self) -> TaskInfo:
from .tasks import TaskUtils

current_task = TaskUtils.current_asyncio_task()
return TaskUtils.create_task_info(current_task)

Expand All @@ -107,10 +115,12 @@ async def create_tcp_connection(
*,
local_address: tuple[str, int] | None = None,
happy_eyeballs_delay: float | None = None,
) -> AsyncioTransportStreamSocketAdapter:
) -> AsyncStreamTransport:
if happy_eyeballs_delay is None:
happy_eyeballs_delay = _DEFAULT_HAPPY_EYEBALLS_DELAY

from ._asyncio_utils import create_connection

socket = await create_connection(
host,
port,
Expand All @@ -120,9 +130,11 @@ async def create_tcp_connection(

return await self.wrap_stream_socket(socket)

async def wrap_stream_socket(self, socket: _socket.socket) -> AsyncioTransportStreamSocketAdapter:
async def wrap_stream_socket(self, socket: _socket.socket) -> AsyncStreamTransport:
from .stream.socket import AsyncioTransportStreamSocketAdapter, StreamReaderBufferedProtocol

socket.setblocking(False)
loop = asyncio.get_running_loop()
loop = self.__asyncio.get_running_loop()
transport, protocol = await loop.create_connection(
_utils.make_callback(StreamReaderBufferedProtocol, loop=loop),
sock=socket,
Expand All @@ -136,10 +148,13 @@ async def create_tcp_listeners(
backlog: int,
*,
reuse_port: bool = False,
) -> Sequence[ListenerSocketAdapter[AsyncioTransportStreamSocketAdapter]]:
) -> Sequence[AsyncListener[AsyncStreamTransport]]:
if not isinstance(backlog, int):
raise TypeError("backlog: Expected an integer")

from ._asyncio_utils import open_listener_sockets_from_getaddrinfo_result, resolve_local_addresses
from .stream.listener import AcceptedSocketFactory, ListenerSocketAdapter

reuse_address: bool = os.name not in ("nt", "cygwin") and sys.platform != "cygwin"
hosts: Sequence[str | None]
if host == "" or host is None:
Expand All @@ -165,7 +180,8 @@ async def create_tcp_listeners(
)

factory = AcceptedSocketFactory()
return [ListenerSocketAdapter(self, sock, factory) for sock in sockets]
listeners = [ListenerSocketAdapter(self, sock, factory) for sock in sockets]
return listeners

async def create_udp_endpoint(
self,
Expand All @@ -174,7 +190,9 @@ async def create_udp_endpoint(
*,
local_address: tuple[str, int] | None = None,
family: int = _socket.AF_UNSPEC,
) -> AsyncioTransportDatagramSocketAdapter:
) -> AsyncDatagramTransport:
from ._asyncio_utils import create_datagram_connection

socket = await create_datagram_connection(
remote_host,
remote_port,
Expand All @@ -183,7 +201,10 @@ async def create_udp_endpoint(
)
return await self.wrap_connected_datagram_socket(socket)

async def wrap_connected_datagram_socket(self, socket: _socket.socket) -> AsyncioTransportDatagramSocketAdapter:
async def wrap_connected_datagram_socket(self, socket: _socket.socket) -> AsyncDatagramTransport:
from .datagram.endpoint import create_datagram_endpoint
from .datagram.socket import AsyncioTransportDatagramSocketAdapter

socket.setblocking(False)
endpoint = await create_datagram_endpoint(sock=socket)
return AsyncioTransportDatagramSocketAdapter(self, endpoint)
Expand All @@ -194,10 +215,11 @@ async def create_udp_listeners(
port: int,
*,
reuse_port: bool = False,
) -> Sequence[DatagramListenerSocketAdapter]:
from .datagram.listener import DatagramListenerProtocol
) -> Sequence[AsyncDatagramListener[tuple[Any, ...]]]:
from ._asyncio_utils import open_listener_sockets_from_getaddrinfo_result, resolve_local_addresses
from .datagram.listener import DatagramListenerProtocol, DatagramListenerSocketAdapter

loop = asyncio.get_running_loop()
loop = self.__asyncio.get_running_loop()

hosts: Sequence[str | None]
if host == "" or host is None:
Expand Down Expand Up @@ -226,22 +248,24 @@ async def create_udp_listeners(
listeners = [await loop.create_datagram_endpoint(protocol_factory, sock=sock) for sock in sockets]
return [DatagramListenerSocketAdapter(self, transport, protocol) for transport, protocol in listeners]

def create_lock(self) -> asyncio.Lock:
return asyncio.Lock()
def create_lock(self) -> ILock:
return self.__asyncio.Lock()

def create_event(self) -> asyncio.Event:
return asyncio.Event()
def create_event(self) -> IEvent:
return self.__asyncio.Event()

def create_condition_var(self, lock: ILock | None = None) -> asyncio.Condition:
def create_condition_var(self, lock: ILock | None = None) -> ICondition:
if lock is not None:
assert isinstance(lock, asyncio.Lock) # nosec assert_used
assert isinstance(lock, self.__asyncio.Lock) # nosec assert_used

return asyncio.Condition(lock)
return self.__asyncio.Condition(lock)

async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
loop = asyncio.get_running_loop()
loop = self.__asyncio.get_running_loop()
ctx = contextvars.copy_context()

from .tasks import TaskUtils

_sniffio_helpers.setup_sniffio_contextvar(ctx, None)

future = loop.run_in_executor(None, functools.partial(ctx.run, func, *args, **kwargs))
Expand All @@ -251,4 +275,6 @@ async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwarg
del future

def create_threads_portal(self) -> ThreadsPortal:
from .threads import ThreadsPortal

return ThreadsPortal()
Loading

0 comments on commit 6a6d005

Please sign in to comment.