diff --git a/mypy.ini b/mypy.ini index 118002257..033a7f7e9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -19,4 +19,4 @@ disallow_untyped_decorators = true disallow_untyped_defs = true warn_return_any = true warn_unreachable = true -files = pychromecast/config.py, pychromecast/const.py, pychromecast/dial.py, pychromecast/error.py, pychromecast/models.py, pychromecast/response_handler.py +files = pychromecast/config.py, pychromecast/const.py, pychromecast/dial.py, pychromecast/discovery.py, pychromecast/error.py, pychromecast/models.py, pychromecast/response_handler.py diff --git a/pychromecast/discovery.py b/pychromecast/discovery.py index c9f988975..1c6272855 100644 --- a/pychromecast/discovery.py +++ b/pychromecast/discovery.py @@ -1,8 +1,12 @@ """Discovers Chromecasts on the network using mDNS/zeroconf.""" +from __future__ import annotations + import abc +from collections.abc import Callable import functools import itertools import logging +import ssl import threading import time from uuid import UUID @@ -28,7 +32,7 @@ class AbstractCastListener(abc.ABC): """Listener for discovering chromecasts.""" @abc.abstractmethod - def add_cast(self, uuid, service): + def add_cast(self, uuid: UUID, service: str) -> None: """A cast has been discovered. uuid: The cast's uuid, this is the dictionary key to find @@ -37,7 +41,7 @@ def add_cast(self, uuid, service): """ @abc.abstractmethod - def remove_cast(self, uuid, service, cast_info): + def remove_cast(self, uuid: UUID, service: str, cast_info: CastInfo) -> None: """A cast has been removed, meaning there are no longer any known services. uuid: The cast's uuid @@ -46,7 +50,7 @@ def remove_cast(self, uuid, service, cast_info): """ @abc.abstractmethod - def update_cast(self, uuid, service): + def update_cast(self, uuid: UUID, service: str) -> None: """A cast has been updated. uuid: The cast's uuid @@ -54,7 +58,9 @@ def update_cast(self, uuid, service): """ -def _is_blocked_from_host_browser(item, block_list, item_type): +def _is_blocked_from_host_browser( + item: str, block_list: list[str], item_type: str +) -> bool: for blocked_prefix in block_list: if item.startswith(blocked_prefix): _LOGGER.debug("%s %s is blocked from host based polling", item_type, item) @@ -62,7 +68,7 @@ def _is_blocked_from_host_browser(item, block_list, item_type): return False -def _is_model_blocked_from_host_browser(model): +def _is_model_blocked_from_host_browser(model: str) -> bool: return _is_blocked_from_host_browser( model, HOST_BROWSER_BLOCKED_MODEL_PREFIXES, "Model" ) @@ -71,36 +77,47 @@ def _is_model_blocked_from_host_browser(model): class SimpleCastListener(AbstractCastListener): """Helper for backwards compatibility.""" - def __init__(self, add_callback=None, remove_callback=None, update_callback=None): + def __init__( + self, + add_callback: Callable[[UUID, str], None] | None = None, + remove_callback: Callable[[UUID, str, CastInfo], None] | None = None, + update_callback: Callable[[UUID, str], None] | None = None, + ): self._add_callback = add_callback self._remove_callback = remove_callback self._update_callback = update_callback - def add_cast(self, uuid, service): + def add_cast(self, uuid: UUID, service: str) -> None: if self._add_callback: self._add_callback(uuid, service) - def remove_cast(self, uuid, service, cast_info): + def remove_cast(self, uuid: UUID, service: str, cast_info: CastInfo) -> None: if self._remove_callback: self._remove_callback(uuid, service, cast_info) - def update_cast(self, uuid, service): + def update_cast(self, uuid: UUID, service: str) -> None: if self._update_callback: self._update_callback(uuid, service) -class ZeroConfListener: +class ZeroConfListener(zeroconf.ServiceListener): """Listener for ZeroConf service browser.""" - def __init__(self, cast_listener, devices, host_browser, lock): + def __init__( + self, + cast_listener: AbstractCastListener, + devices: dict[UUID, CastInfo], + host_browser: HostBrowser, + lock: threading.Lock, + ) -> None: self._cast_listener = cast_listener self._devices = devices self._host_browser = host_browser self._services_lock = lock - def remove_service(self, _zconf, typ, name): + def remove_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: """Called by zeroconf when an mDNS service is lost.""" - _LOGGER.debug("remove_service %s, %s", typ, name) + _LOGGER.debug("remove_service %s, %s", type_, name) cast_info = None device_removed = False uuid = None @@ -116,7 +133,7 @@ def remove_service(self, _zconf, typ, name): break if not cast_info: - _LOGGER.debug("remove_service unknown %s, %s", typ, name) + _LOGGER.debug("remove_service unknown %s, %s", type_, name) return if device_removed: @@ -124,18 +141,24 @@ def remove_service(self, _zconf, typ, name): else: self._cast_listener.update_cast(uuid, name) - def update_service(self, zconf, typ, name): + def update_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: """Called by zeroconf when an mDNS service is updated.""" - _LOGGER.debug("update_service %s, %s", typ, name) - self._add_update_service(zconf, typ, name, self._cast_listener.update_cast) + _LOGGER.debug("update_service %s, %s", type_, name) + self._add_update_service(zc, type_, name, self._cast_listener.update_cast) - def add_service(self, zconf, typ, name): + def add_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: """Called by zeroconf when an mDNS service is discovered.""" - _LOGGER.debug("add_service %s, %s", typ, name) - self._add_update_service(zconf, typ, name, self._cast_listener.add_cast) + _LOGGER.debug("add_service %s, %s", type_, name) + self._add_update_service(zc, type_, name, self._cast_listener.add_cast) # pylint: disable-next=too-many-locals - def _add_update_service(self, zconf, typ, name, callback): + def _add_update_service( + self, + zconf: zeroconf.Zeroconf, + typ: str, + name: str, + callback: Callable[[UUID, str], None], + ) -> None: """Add or update a service.""" service = None tries = 0 @@ -161,25 +184,38 @@ def _add_update_service(self, zconf, typ, name, callback): _LOGGER.debug("_add_update_service failed to add %s, %s", typ, name) return - def get_value(key): + if service.port is None: + _LOGGER.debug("_add_update_service port is None") + return + + def get_value(key: str) -> str | None: """Retrieve value and decode to UTF-8.""" value = service.properties.get(key.encode("utf-8")) - if value is None or isinstance(value, str): + # zeroconf would keep str version of cached items, this check + # can be removed if we pin zeroconf to a version where this is + # removed. + if value is None or isinstance(value, str): # type: ignore[unreachable] return value return value.decode("utf-8") addresses = service.parsed_addresses() host = addresses[0] if addresses else service.server + if host is None: + _LOGGER.debug( + "_add_update_service failed to get host for %s, %s", typ, name + ) + return + # Store the host, in case mDNS stops working self._host_browser.add_hosts([host]) friendly_name = get_value("fn") - model_name = get_value("md") - uuid = get_value("id") + model_name = get_value("md") or "Unknown model name" + uuid_str = get_value("id") - if not uuid: + if not uuid_str: _LOGGER.debug( "_add_update_service failed to get uuid for %s, %s", typ, name ) @@ -187,7 +223,7 @@ def get_value(key): # Ignore incorrect UUIDs from third-party Chromecast emulators try: - uuid = UUID(uuid) + uuid = UUID(uuid_str) except ValueError: _LOGGER.debug( "_add_update_service failed due to bad uuid for %s, %s, model %s", @@ -201,6 +237,8 @@ def get_value(key): # Lock because the HostBrowser may also add or remove items with self._services_lock: + cast_type: str | None + manufacturer: str | None if service.port != 8009: cast_type = CAST_TYPE_GROUP manufacturer = MF_GOOGLE @@ -240,7 +278,7 @@ def get_value(key): class HostStatus: """Status of known host.""" - def __init__(self): + def __init__(self) -> None: self.failcount = 0 self.no_polling = False @@ -252,25 +290,30 @@ def __init__(self): class HostBrowser(threading.Thread): """Repeateadly poll a set of known hosts.""" - def __init__(self, cast_listener, devices, lock): + def __init__( + self, + cast_listener: AbstractCastListener, + devices: dict[UUID, CastInfo], + lock: threading.Lock, + ) -> None: super().__init__(daemon=True) self._cast_listener = cast_listener self._devices = devices - self._known_hosts = {} + self._known_hosts: dict[str, HostStatus] = {} self._next_update = time.time() self._services_lock = lock self._start_requested = False - self._context = None + self._context: ssl.SSLContext | None = None self.stop = threading.Event() - def add_hosts(self, known_hosts): + def add_hosts(self, known_hosts: list[str]) -> None: """Add a list of known hosts to the set.""" for host in known_hosts: if host not in self._known_hosts: _LOGGER.debug("Addded host %s", host) self._known_hosts[host] = HostStatus() - def update_hosts(self, known_hosts): + def update_hosts(self, known_hosts: list[str] | None) -> None: """Update the set of known hosts. Note: Removed hosts will no longer be polled, but services of any associated @@ -286,7 +329,7 @@ def update_hosts(self, known_hosts): _LOGGER.debug("Removed host %s", host) self._known_hosts.pop(host) - def run(self): + def run(self) -> None: """Start worker thread.""" _LOGGER.debug("HostBrowser thread started") self._context = get_ssl_context() @@ -300,12 +343,12 @@ def run(self): raise _LOGGER.debug("HostBrowser thread done") - def _poll_hosts(self): + def _poll_hosts(self) -> None: # Iterate over a copy because other threads may modify the known_hosts list known_hosts = list(self._known_hosts.keys()) for host in known_hosts: - devices = [] - uuids = [] + devices: list[tuple[int, str, str, UUID, str, str]] = [] + uuids: list[UUID] = [] if self.stop.is_set(): break try: @@ -323,12 +366,17 @@ def _poll_hosts(self): if not device_status: hoststatus.failcount += 1 if hoststatus.failcount == HOSTLISTENER_MAX_FAIL: + # We can't contact the host, drop all its devices and UUIDs self._update_devices(host, devices, uuids) hoststatus.failcount = min( hoststatus.failcount, HOSTLISTENER_MAX_FAIL + 1 ) continue + if not device_status.uuid: + _LOGGER.debug("host %s does not report UUID", host) + continue + if ( device_status.cast_type != CAST_TYPE_AUDIO or _is_model_blocked_from_host_browser(device_status.model_name) @@ -369,7 +417,7 @@ def _poll_hosts(self): # ports of dynamic groups are not present in the eureka_info reply. if group.host and group.host not in self._known_hosts: self.add_hosts([group.host]) - if group.port is None or group.host != host: + if group.port is None or group.uuid is None or group.host != host: continue devices.append( ( @@ -385,8 +433,13 @@ def _poll_hosts(self): self._update_devices(host, devices, uuids) - def _update_devices(self, host, devices, host_uuids): - callbacks = [] + def _update_devices( + self, + host: str, + devices: list[tuple[int, str, str, UUID, str, str]], + host_uuids: list[UUID], + ) -> None: + callbacks: list[Callable[[], None]] = [] # Lock because the ZeroConfListener may also add or remove items with self._services_lock: @@ -424,15 +477,15 @@ def _update_devices(self, host, devices, host_uuids): def _add_host_service( self, - host, - port, - friendly_name, - model_name, - uuid, - callbacks, - cast_type, - manufacturer, - ): + host: str, + port: int, + friendly_name: str, + model_name: str, + uuid: UUID, + callbacks: list[Callable[[], None]], + cast_type: str, + manufacturer: str, + ) -> None: service_info = HostServiceInfo(host, port) callback = self._cast_listener.add_cast @@ -477,10 +530,14 @@ def _add_host_service( _LOGGER.debug( "Host %s (%s) up, adding or updating host based service", name, uuid ) - if callback: - callbacks.append(functools.partial(callback, uuid, name)) + callbacks.append(functools.partial(callback, uuid, name)) - def _remove_host_service(self, host, uuid, callbacks): + def _remove_host_service( + self, + host: str, + uuid: UUID, + callbacks: list[Callable[[], None]], + ) -> None: if uuid not in self._devices: return @@ -519,11 +576,16 @@ class CastBrowser: instance is passed, a new instance will be created. """ - def __init__(self, cast_listener, zeroconf_instance=None, known_hosts=None): + def __init__( + self, + cast_listener: AbstractCastListener, + zeroconf_instance: zeroconf.Zeroconf | None = None, + known_hosts: list[str] | None = None, + ) -> None: self._cast_listener = cast_listener self.zc = zeroconf_instance # pylint: disable=invalid-name - self._zc_browser = None - self.devices = {} + self._zc_browser: zeroconf.ServiceBrowser | None = None + self.devices: dict[UUID, CastInfo] = {} self.services = self.devices # For backwards compatibility self._services_lock = threading.Lock() self.host_browser = HostBrowser( @@ -536,17 +598,17 @@ def __init__(self, cast_listener, zeroconf_instance=None, known_hosts=None): self.host_browser.add_hosts(known_hosts) @property - def count(self): + def count(self) -> int: """Number of discovered cast devices.""" return len(self.devices) - def set_zeroconf_instance(self, zeroconf_instance): + def set_zeroconf_instance(self, zeroconf_instance: zeroconf.Zeroconf) -> None: """Set zeroconf_instance.""" if self.zc: return self.zc = zeroconf_instance - def start_discovery(self): + def start_discovery(self) -> None: """ This method will start discovering chromecasts on separate threads. When a chromecast is discovered, callback will be called with the @@ -565,7 +627,7 @@ def start_discovery(self): ) self.host_browser.start() - def stop_discovery(self): + def stop_discovery(self) -> None: """Stop the chromecast discovery threads.""" if self._zc_browser: try: @@ -584,7 +646,12 @@ class CastListener(CastBrowser): Deprecated as of February 2021, will be removed in June 2024. """ - def __init__(self, add_callback=None, remove_callback=None, update_callback=None): + def __init__( + self, + add_callback: Callable[[UUID, str], None] | None = None, + remove_callback: Callable[[UUID, str, CastInfo], None] | None = None, + update_callback: Callable[[UUID, str], None] | None = None, + ): _LOGGER.info( "CastListener is deprecated and will be removed in June 2024, update to use CastBrowser instead" ) @@ -592,7 +659,9 @@ def __init__(self, add_callback=None, remove_callback=None, update_callback=None super().__init__(listener) -def start_discovery(cast_browser, zeroconf_instance): +def start_discovery( + cast_browser: CastBrowser, zeroconf_instance: zeroconf.Zeroconf +) -> CastBrowser: """Start discovering chromecasts on the network. Deprecated as of February 2021, will be removed in June 2024. @@ -605,7 +674,7 @@ def start_discovery(cast_browser, zeroconf_instance): return cast_browser -def stop_discovery(cast_browser): +def stop_discovery(cast_browser: CastBrowser) -> None: """Stop the chromecast discovery threads. Deprecated as of February 2021, will be removed in June 2024. @@ -617,8 +686,11 @@ def stop_discovery(cast_browser): def discover_chromecasts( - max_devices=None, timeout=DISCOVER_TIMEOUT, zeroconf_instance=None, known_hosts=None -): + max_devices: int | None = None, + timeout: float = DISCOVER_TIMEOUT, + zeroconf_instance: zeroconf.Zeroconf | None = None, + known_hosts: list[str] | None = None, +) -> tuple[list[CastInfo], CastBrowser]: """ Discover chromecasts on the network. @@ -637,7 +709,7 @@ def discover_chromecasts( "discover_chromecasts is deprecated and will be removed in June 2024, update to use CastBrowser instead." ) - def add_callback(_uuid, _service): + def add_callback(_uuid: UUID, _service: str) -> None: """Called when a new chromecast has been discovered.""" if max_devices is not None and browser.count >= max_devices: discover_complete.set() @@ -654,12 +726,12 @@ def add_callback(_uuid, _service): def discover_listed_chromecasts( - friendly_names=None, - uuids=None, - discovery_timeout=DISCOVER_TIMEOUT, - zeroconf_instance=None, - known_hosts=None, -): + friendly_names: list[str] | None = None, + uuids: list[UUID] | None = None, + discovery_timeout: float = DISCOVER_TIMEOUT, + zeroconf_instance: zeroconf.Zeroconf | None = None, + known_hosts: list[str] | None = None, +) -> tuple[list[CastInfo], CastBrowser]: """ Searches the network for chromecast devices matching a list of friendly names or a list of UUIDs. @@ -677,12 +749,12 @@ def discover_listed_chromecasts( :param zeroconf_instance: An existing zeroconf instance. """ - cc_list = {} + cc_list: dict[UUID, CastInfo] = {} - def add_callback(uuid, service): + def add_callback(uuid: UUID, service: str) -> None: _LOGGER.debug("Got cast %s, %s", uuid, service) - service = browser.devices[uuid] - friendly_name = service[3] + cast_info = browser.devices[uuid] + friendly_name = cast_info.friendly_name if uuids and uuid in uuids: cc_list[uuid] = browser.devices[uuid] uuids.remove(uuid)