Skip to content

Commit

Permalink
Internal: Moved open_listener_sockets_from_getaddrinfo_result out of …
Browse files Browse the repository at this point in the history
…asyncio backend
  • Loading branch information
francis-clairicia committed Jun 29, 2024
1 parent 6a6d005 commit 852b38d
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 222 deletions.
61 changes: 61 additions & 0 deletions src/easynetwork/lowlevel/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,67 @@ def set_reuseport(sock: SupportsSocketOptions) -> None:
raise ValueError("reuse_port not supported by socket module, SO_REUSEPORT defined but not implemented.") from None


def open_listener_sockets_from_getaddrinfo_result(
infos: Iterable[tuple[int, int, int, str, tuple[Any, ...]]],
*,
backlog: int | None,
reuse_address: bool,
reuse_port: bool,
) -> list[_socket.socket]:
sockets: list[_socket.socket] = []
reuse_address = reuse_address and hasattr(_socket, "SO_REUSEADDR")
with contextlib.ExitStack() as _whole_context_stack:
errors: list[OSError] = []
_whole_context_stack.callback(errors.clear)

socket_exit_stack = _whole_context_stack.enter_context(contextlib.ExitStack())

for af, socktype, proto, _, sa in infos:
try:
sock = socket_exit_stack.enter_context(contextlib.closing(_socket.socket(af, socktype, proto)))
except OSError:
# Assume it's a bad family/type/protocol combination.
continue
sockets.append(sock)
if reuse_address:
try:
sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, True)
except OSError:
# Will fail later on bind()
pass
if reuse_port:
set_reuseport(sock)
# Disable IPv4/IPv6 dual stack support (enabled by
# default on Linux) which makes a single socket
# listen on both address families.
if af == _socket.AF_INET6:
if hasattr(_socket, "IPPROTO_IPV6"):
sock.setsockopt(_socket.IPPROTO_IPV6, _socket.IPV6_V6ONLY, True)
if "%" in sa[0]:
addr, scope_id = sa[0].split("%", 1)
sa = (addr, sa[1], 0, int(scope_id))
try:
sock.bind(sa)
except OSError as exc:
errors.append(
OSError(
exc.errno, f"error while attempting to bind to address {sa!r}: {exc.strerror.lower()}"
).with_traceback(exc.__traceback__)
)
continue
if backlog is not None:
sock.listen(backlog)

if errors:
# No need to call errors.clear(), this is done by exit stack
raise ExceptionGroup("Error when trying to create listeners", errors)

# There were no errors, therefore do not close the sockets
socket_exit_stack.pop_all()

return sockets


def exception_with_notes(exc: _T_Exception, notes: str | Iterable[str]) -> _T_Exception:
if isinstance(notes, str):
notes = (notes,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,19 @@
"create_connection",
"create_datagram_connection",
"ensure_resolved",
"open_listener_sockets_from_getaddrinfo_result",
"resolve_local_addresses",
"wait_until_readable",
"wait_until_writable",
]

import asyncio
import contextlib
import itertools
import math
import socket as _socket
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
from typing import Any, cast

from .... import _utils


async def ensure_resolved(
host: str | None,
Expand Down Expand Up @@ -285,67 +281,6 @@ async def create_datagram_connection(
)


def open_listener_sockets_from_getaddrinfo_result(
infos: Iterable[tuple[int, int, int, str, tuple[Any, ...]]],
*,
backlog: int | None,
reuse_address: bool,
reuse_port: bool,
) -> list[_socket.socket]:
sockets: list[_socket.socket] = []
reuse_address = reuse_address and hasattr(_socket, "SO_REUSEADDR")
with contextlib.ExitStack() as _whole_context_stack:
errors: list[OSError] = []
_whole_context_stack.callback(errors.clear)

socket_exit_stack = _whole_context_stack.enter_context(contextlib.ExitStack())

for af, socktype, proto, _, sa in infos:
try:
sock = socket_exit_stack.enter_context(contextlib.closing(_socket.socket(af, socktype, proto)))
except OSError:
# Assume it's a bad family/type/protocol combination.
continue
sockets.append(sock)
if reuse_address:
try:
sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, True)
except OSError:
# Will fail later on bind()
pass
if reuse_port:
_utils.set_reuseport(sock)
# Disable IPv4/IPv6 dual stack support (enabled by
# default on Linux) which makes a single socket
# listen on both address families.
if af == _socket.AF_INET6:
if hasattr(_socket, "IPPROTO_IPV6"):
sock.setsockopt(_socket.IPPROTO_IPV6, _socket.IPV6_V6ONLY, True)
if "%" in sa[0]:
addr, scope_id = sa[0].split("%", 1)
sa = (addr, sa[1], 0, int(scope_id))
try:
sock.bind(sa)
except OSError as exc:
errors.append(
OSError(
exc.errno, f"error while attempting to bind to address {sa!r}: {exc.strerror.lower()}"
).with_traceback(exc.__traceback__)
)
continue
if backlog is not None:
sock.listen(backlog)

if errors:
# No need to call errors.clear(), this is done by exit stack
raise ExceptionGroup("Error when trying to create listeners", errors)

# There were no errors, therefore do not close the sockets
socket_exit_stack.pop_all()

return sockets


def wait_until_readable(sock: _socket.socket, loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]:
def on_fut_done(f: asyncio.Future[None]) -> None:
loop.remove_reader(sock)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async def create_tcp_listeners(
if not isinstance(backlog, int):
raise TypeError("backlog: Expected an integer")

from ._asyncio_utils import open_listener_sockets_from_getaddrinfo_result, resolve_local_addresses
from ._asyncio_utils import resolve_local_addresses
from .stream.listener import AcceptedSocketFactory, ListenerSocketAdapter

reuse_address: bool = os.name not in ("nt", "cygwin") and sys.platform != "cygwin"
Expand All @@ -172,7 +172,7 @@ async def create_tcp_listeners(
_socket.SOCK_STREAM,
)

sockets: list[_socket.socket] = open_listener_sockets_from_getaddrinfo_result(
sockets: list[_socket.socket] = _utils.open_listener_sockets_from_getaddrinfo_result(
infos,
backlog=backlog,
reuse_address=reuse_address,
Expand Down Expand Up @@ -216,7 +216,7 @@ async def create_udp_listeners(
*,
reuse_port: bool = False,
) -> Sequence[AsyncDatagramListener[tuple[Any, ...]]]:
from ._asyncio_utils import open_listener_sockets_from_getaddrinfo_result, resolve_local_addresses
from ._asyncio_utils import resolve_local_addresses
from .datagram.listener import DatagramListenerProtocol, DatagramListenerSocketAdapter

loop = self.__asyncio.get_running_loop()
Expand All @@ -237,7 +237,7 @@ async def create_udp_listeners(
_socket.SOCK_DGRAM,
)

sockets: list[_socket.socket] = open_listener_sockets_from_getaddrinfo_result(
sockets: list[_socket.socket] = _utils.open_listener_sockets_from_getaddrinfo_result(
infos,
backlog=None,
reuse_address=False,
Expand Down
18 changes: 9 additions & 9 deletions tests/unit_test/test_async/test_asyncio_backend/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ async def test____create_tcp_listeners____open_listener_sockets(
return_value=addrinfo_list,
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
return_value=[mock_tcp_socket],
)
mock_ListenerSocketAdapter: MagicMock = mocker.patch(
Expand Down Expand Up @@ -387,7 +387,7 @@ async def test____create_tcp_listeners____bind_to_any_interfaces(
return_value=addrinfo_list,
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
return_value=[mock_tcp_socket, mock_tcp_socket],
)
mock_ListenerSocketAdapter: MagicMock = mocker.patch(
Expand Down Expand Up @@ -457,7 +457,7 @@ async def test____create_tcp_listeners____bind_to_several_hosts(
side_effect=[[info] for info in addrinfo_list],
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
return_value=[mock_tcp_socket, mock_tcp_socket],
)
mock_ListenerSocketAdapter: MagicMock = mocker.patch(
Expand Down Expand Up @@ -514,7 +514,7 @@ async def test____create_tcp_listeners____error_getaddrinfo_returns_empty_list(
return_value=[],
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
side_effect=AssertionError,
)
mock_ListenerSocketAdapter: MagicMock = mocker.patch(
Expand Down Expand Up @@ -559,7 +559,7 @@ async def test____create_tcp_listeners____invalid_backlog(
return_value=[],
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
side_effect=AssertionError,
)
mock_ListenerSocketAdapter: MagicMock = mocker.patch(
Expand Down Expand Up @@ -682,7 +682,7 @@ async def test____create_udp_listeners____open_listener_sockets(
return_value=addrinfo_list,
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
return_value=[mock_udp_socket],
)
mock_create_datagram_endpoint: AsyncMock = mocker.patch.object(
Expand Down Expand Up @@ -761,7 +761,7 @@ async def test____create_udp_listeners____bind_to_local_interfaces(
return_value=addrinfo_list,
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
return_value=[mock_udp_socket, mock_udp_socket],
)
mock_create_datagram_endpoint: AsyncMock = mocker.patch.object(
Expand Down Expand Up @@ -844,7 +844,7 @@ async def test____create_udp_listeners____bind_to_several_hosts(
return_value=addrinfo_list,
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
return_value=[mock_udp_socket, mock_udp_socket],
)
mock_create_datagram_endpoint: AsyncMock = mocker.patch.object(
Expand Down Expand Up @@ -911,7 +911,7 @@ async def test____create_udp_listeners____error_getaddrinfo_returns_empty_list(
return_value=[],
)
mock_open_listeners = mocker.patch(
f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.open_listener_sockets_from_getaddrinfo_result",
"easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result",
side_effect=AssertionError,
)
mock_create_datagram_endpoint: AsyncMock = mocker.patch.object(
Expand Down
Loading

0 comments on commit 852b38d

Please sign in to comment.