Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for happy eyeballs (RFC 8305) #349

Merged
merged 18 commits into from
Dec 12, 2023
99 changes: 69 additions & 30 deletions aiohomekit/controller/ip/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import asyncio
import logging
import socket
from typing import TYPE_CHECKING, Any

import aiohappyeyeballs
from async_interrupt import interrupt

from aiohomekit.crypto.chacha20poly1305 import (
Expand Down Expand Up @@ -49,6 +51,22 @@
logger = logging.getLogger(__name__)


def _convert_hosts_to_addr_infos(
hosts: list[str], port: int
) -> list[aiohappyeyeballs.AddrInfoType]:
"""Converts the list of hosts to a list of addr_infos.
The list of hosts is the result of a DNS lookup. The list of
addr_infos is the result of a call to `socket.getaddrinfo()`.
"""
addr_infos: list[aiohappyeyeballs.AddrInfoType] = []
for host in hosts:
is_ipv6 = ":" in host
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
addr = (host, port, 0, 0) if is_ipv6 else (host, port)
addr_infos.append((family, socket.SOCK_STREAM, socket.IPPROTO_TCP, host, addr))
return addr_infos


class ConnectionReady(Exception):
"""Raised when a connection is ready to be retried."""

Expand All @@ -58,7 +76,6 @@ class InsecureHomeKitProtocol(asyncio.Protocol):

def __init__(self, connection: HomeKitConnection) -> None:
self.connection = connection
self.host = ":".join((connection.host, str(connection.port)))
self.result_cbs: list[asyncio.Future[HttpResponse]] = []
self.current_response = HttpResponse()
self.loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -218,10 +235,10 @@ def data_received(self, data: bytes) -> None:

class HomeKitConnection:
def __init__(
self, owner: IpPairing, host: str, port: int, concurrency_limit: int = 1
self, owner: IpPairing, hosts: list[str], port: int, concurrency_limit: int = 1
) -> None:
self.owner = owner
self.host = host
self.hosts = hosts
self.port = port

self.closing: bool = False
Expand All @@ -241,13 +258,15 @@ def __init__(
self._concurrency_limit = asyncio.Semaphore(concurrency_limit)
self._reconnect_future: asyncio.Future[None] | None = None
self._last_connector_error: Exception | None = None
self.connected_host: str | None = None
self.host_header: str | None = None

@property
def name(self) -> str:
"""Return the name of the connection."""
if self.owner:
return self.owner.name
return f"{self.host}:{self.port}"
return f"{self.connected_host or self.hosts}:{self.port}"

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -472,12 +491,9 @@ async def request(
"Connection lost before request could be sent"
)

buffer = []
buffer.append(f"{method.upper()} {target} HTTP/1.1")

# WARNING: It is vital that a Host: header is present or some devices
# will reject the request.
buffer.append(f"Host: {self.host}")
buffer = [f"{method.upper()} {target} HTTP/1.1", self.host_header]

if headers:
for header, value in headers:
Expand All @@ -502,7 +518,7 @@ async def request(
async with self._concurrency_limit:
if not self.protocol:
raise AccessoryDisconnectedError("Tried to send while not connected")
logger.debug("%s: raw request: %r", self.host, request_bytes)
logger.debug("%s: raw request: %r", self.connected_host, request_bytes)
resp = await self.protocol.send_bytes(request_bytes)

if resp.code >= 400 and resp.code <= 499:
Expand All @@ -512,7 +528,7 @@ async def request(
response=resp,
)

logger.debug("%s: raw response: %r", self.host, resp.body)
logger.debug("%s: raw response: %r", self.connected_host, resp.body)

return resp

Expand Down Expand Up @@ -550,20 +566,41 @@ async def _connect_once(self) -> None:
"""_connect_once must only ever be called from _reconnect to ensure its done with a lock."""
loop = asyncio.get_event_loop()

logger.debug("Attempting connection to %s:%s", self.host, self.port)

try:
async with asyncio_timeout(10):
self.transport, self.protocol = await loop.create_connection(
lambda: InsecureHomeKitProtocol(self), self.host, self.port
)

except asyncio.TimeoutError:
raise TimeoutError("Timeout")
logger.debug("Attempting connection to %s:%s", self.hosts, self.port)

except OSError as e:
raise ConnectionError(str(e))
addr_infos = _convert_hosts_to_addr_infos(self.hosts, self.port)

last_exception: Exception | None = None
sock: socket.socket | None = None
interleave = 1
while addr_infos:
try:
async with asyncio_timeout(10):
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
interleave=interleave,
loop=self._loop,
)
break
except (OSError, asyncio.TimeoutError) as err:
last_exception = err
aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, interleave)

if sock is None:
if isinstance(last_exception, asyncio.TimeoutError):
raise TimeoutError("Timeout") from last_exception
raise ConnectionError(str(last_exception)) from last_exception

self.transport, self.protocol = await loop.create_connection(
lambda: InsecureHomeKitProtocol(self), sock=sock
)
connected_host = sock.getpeername()[0]
self.connected_host = connected_host
if ":" in connected_host:
self.host_header = f"Host: [{connected_host}]:{self.port}"
else:
self.host_header = f"Host: {connected_host}:{self.port}"
if self.owner:
await self.owner.connection_made(False)

Expand All @@ -582,7 +619,7 @@ async def _reconnect(self) -> None:
async with self._connect_lock:
interval = 0.5

logger.debug("Starting reconnect loop to %s:%s", self.host, self.port)
logger.debug("Starting reconnect loop to %s:%s", self.hosts, self.port)

while not self.closing:
self._last_connector_error = None
Expand Down Expand Up @@ -638,7 +675,7 @@ def event_received(self, event: HttpResponse) -> None:
self.owner.event_received(parsed)

def __repr__(self) -> str:
return f"HomeKitConnection(host={self.host!r}, port={self.port!r})"
return f"HomeKitConnection(host={(self.connected_host or self.hosts)!r}, port={self.port!r})"


class SecureHomeKitConnection(HomeKitConnection):
Expand All @@ -647,7 +684,7 @@ class SecureHomeKitConnection(HomeKitConnection):
def __init__(self, owner: IpPairing, pairing_data: dict[str, Any]) -> None:
super().__init__(
owner,
pairing_data["AccessoryIP"],
pairing_data.get("AccessoryIPs", [pairing_data["AccessoryIP"]]),
pairing_data["AccessoryPort"],
)
self.pairing_data = pairing_data
Expand All @@ -663,14 +700,14 @@ async def _connect_once(self):
if self.owner and self.owner.description:
pairing = self.owner
try:
if self.host != pairing.description.address:
if set(self.hosts) != set(pairing.description.addresses):
logger.debug(
"%s: Host changed from %s to %s",
pairing.name,
self.host,
pairing.description.address,
self.hosts,
pairing.description.addresses,
)
self.host = pairing.description.address
self.hosts = pairing.description.addresses

if self.port != pairing.description.port:
logger.debug(
Expand Down Expand Up @@ -714,7 +751,9 @@ async def _connect_once(self):

self.is_secure = True

logger.debug("Secure connection to %s:%s established", self.host, self.port)
logger.debug(
"Secure connection to %s:%s established", self.connected_host, self.port
)

if self.owner:
await self.owner.connection_made(True)
5 changes: 4 additions & 1 deletion aiohomekit/controller/ip/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ class IpDiscovery(ZeroconfDiscovery):
def __init__(self, controller, description: HomeKitService):
super().__init__(description)
self.controller = controller
self.connection = HomeKitConnection(None, description.address, description.port)
self.connection = HomeKitConnection(
None, description.addresses, description.port
)

def __repr__(self):
return f"IPDiscovery(host={self.description.address}, port={self.description.port})"
Expand Down Expand Up @@ -92,6 +94,7 @@ async def finish_pairing(pin: str) -> IpPairing:
break

pairing["AccessoryIP"] = self.description.address
pairing["AccessoryIPs"] = self.description.addresses
pairing["AccessoryPort"] = self.description.port
pairing["Connection"] = "IP"

Expand Down
6 changes: 4 additions & 2 deletions aiohomekit/controller/ip/pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ def poll_interval(self) -> timedelta:
@property
def name(self) -> str:
"""Return the name of the pairing with the address."""
connection = self.connection
host = connection.connected_host or connection.hosts
if self.description:
return f"{self.description.name} [{self.connection.host}:{self.connection.port}] (id={self.id})"
return f"[{self.connection.host}:{self.connection.port}] (id={self.id})"
return f"{self.description.name} [{host}:{connection.port}] (id={self.id})"
return f"[{host}:{connection.port}] (id={self.id})"

def event_received(self, event):
self._callback_listeners(format_characteristic_list(event))
Expand Down
22 changes: 13 additions & 9 deletions aiohomekit/zeroconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@ def from_service_info(cls, service: AsyncServiceInfo) -> HomeKitService:
# This means the first address will always be the most recently added
# address of the given IP version.
#
for ip_addr in addresses:
if not ip_addr.is_link_local and not ip_addr.is_unspecified:
address = str(ip_addr)
break
if not address:
valid_addresses = [
str(ip_addr)
for ip_addr in addresses
if not ip_addr.is_link_local and not ip_addr.is_unspecified
]
if not valid_addresses:
raise ValueError(
"Invalid HomeKit Zeroconf record: Missing non-link-local or unspecified address"
)
address = valid_addresses[0]

props: dict[str, str] = {
k.decode("utf-8").lower(): v.decode("utf-8")
Expand All @@ -118,7 +120,7 @@ def from_service_info(cls, service: AsyncServiceInfo) -> HomeKitService:
protocol_version=props.get("pv", "1.0"),
type=service.type,
address=address,
addresses=[str(ip_addr) for ip_addr in addresses],
addresses=valid_addresses,
port=service.port,
)

Expand All @@ -127,13 +129,13 @@ class ZeroconfServiceListener(ServiceListener):
"""An empty service listener."""

def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
pass
"""A service has been added."""

def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
pass
"""A service has been removed."""

def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
pass
"""A service has been updated."""


def find_brower_for_hap_type(azc: AsyncZeroconf, hap_type: str) -> AsyncServiceBrowser:
Expand All @@ -158,6 +160,8 @@ def _update_from_discovery(self, description: HomeKitService):


class ZeroconfPairing(AbstractPairing):
description: HomeKitService

def _async_endpoint_changed(self) -> None:
"""The IP and/or port of the accessory has changed."""
pass
Expand Down
Loading
Loading