diff --git a/pyproject.toml b/pyproject.toml index e63e7b1..393797d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ dev = [ "types-pyserial", "pytest", "pytest-cov", - "types-paho-mqtt", + "paho-mqtt >=2,<3" ] vision = ["opencv-python-headless >=4,<5"] -mqtt = ["paho-mqtt >=1.6,<2"] +mqtt = ["paho-mqtt >=2,<3"] diff --git a/sbot/mqtt.py b/sbot/mqtt.py index fe253e4..ad132b7 100644 --- a/sbot/mqtt.py +++ b/sbot/mqtt.py @@ -5,6 +5,7 @@ import logging import os from typing import Any, Callable, TypedDict +from urllib.parse import urlparse import paho.mqtt.client as mqtt @@ -19,7 +20,7 @@ def __init__( self, client_name: str | None = None, topic_prefix: str | None = None, - mqtt_version: int = mqtt.MQTTv5, + mqtt_version: mqtt.MQTTProtocolVersion = mqtt.MQTTProtocolVersion.MQTTv5, use_tls: bool | str = False, username: str = '', password: str = '', @@ -31,7 +32,11 @@ def __init__( self._client_name = client_name self._img_topic = 'img' - self._client = mqtt.Client(client_id=client_name, protocol=mqtt_version) + self._client = mqtt.Client( + callback_api_version=mqtt.CallbackAPIVersion.VERSION2, + client_id=client_name, + protocol=mqtt_version, + ) self._client.on_connect = self._on_connect if use_tls: @@ -53,8 +58,8 @@ def connect(self, host: str, port: int) -> None: return try: - self._client.connect(host, port, keepalive=60) - except (TimeoutError, ValueError, ConnectionRefusedError): + self._client.connect_async(host, port, keepalive=60) + except ValueError: LOGGER.error(f"Failed to connect to MQTT broker at {host}:{port}") return self._client.loop_start() @@ -149,12 +154,16 @@ def wrapped_publish( retain=retain, abs_topic=abs_topic) def _on_connect( - self, client: mqtt.Client, userdata: Any, flags: dict[str, int], rc: int, + self, + client: mqtt.Client, + userdata: Any, + connect_flags: mqtt.ConnectFlags, + reason_code: mqtt.ReasonCode, properties: mqtt.Properties | None = None, ) -> None: - if rc != mqtt.CONNACK_ACCEPTED: + if reason_code.is_failure: LOGGER.warning( - f"Failed to connect to MQTT broker. Return code: {mqtt.error_string(rc)}" + f"Failed to connect to MQTT broker. Return code: {reason_code.getName()}" # type: ignore[no-untyped-call] # noqa: E501 ) return @@ -181,36 +190,17 @@ def get_mqtt_variables() -> MQTTVariables: # url format: mqtt[s]://:@:/ mqtt_url = os.environ['SBOT_MQTT_URL'] - scheme, rest = mqtt_url.split('://', maxsplit=1) - # username and password are optional - try: - user_pass, host_port_topic = rest.rsplit('@', maxsplit=1) - except ValueError: - username, password = None, None - host_port_topic = rest - else: - try: - username, password = user_pass.split(':', maxsplit=1) - except ValueError: - # username can be supplied without password - username = user_pass - password = None - - host_port, topic_root = host_port_topic.split('/', maxsplit=1) - use_tls = (scheme == 'mqtts') - try: - host, port_str = host_port.split(':', maxsplit=1) - port = int(port_str) - except ValueError: - # use default port for scheme - host = host_port - port = 8883 if use_tls else 1883 + url_parts = urlparse(mqtt_url, allow_fragments=False) + use_tls = (url_parts.scheme == 'mqtts') + + if url_parts.hostname is None: + raise ValueError("MQTT URL is missing a hostname.") return MQTTVariables( - host=host, - port=port, - topic_prefix=topic_root, + host=url_parts.hostname, + port=url_parts.port or (8883 if use_tls else 1883), + topic_prefix=url_parts.path.lstrip('/'), use_tls=use_tls, - username=username, - password=password, + username=url_parts.username, + password=url_parts.password, )