Skip to content

Commit

Permalink
Merge pull request #344 from sourcebots/mqtt-rework
Browse files Browse the repository at this point in the history
MQTT updates
  • Loading branch information
WillB97 authored Jul 1, 2024
2 parents ec8b4a3 + ef65562 commit d87bf96
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 38 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ dev = [
"types-pyserial",
"pytest",
"pytest-cov",
"types-paho-mqtt",
"paho-mqtt >=2,<3"
]
vision = ["opencv-python-headless >=4,<5"]
mqtt = ["paho-mqtt >=1.6,<2"]
mqtt = ["paho-mqtt >=2,<3"]
62 changes: 26 additions & 36 deletions sbot/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
from typing import Any, Callable, TypedDict
from urllib.parse import urlparse

import paho.mqtt.client as mqtt

Expand All @@ -19,7 +20,7 @@ def __init__(
self,
client_name: str | None = None,
topic_prefix: str | None = None,
mqtt_version: int = mqtt.MQTTv5,
mqtt_version: mqtt.MQTTProtocolVersion = mqtt.MQTTProtocolVersion.MQTTv5,
use_tls: bool | str = False,
username: str = '',
password: str = '',
Expand All @@ -31,7 +32,11 @@ def __init__(
self._client_name = client_name
self._img_topic = 'img'

self._client = mqtt.Client(client_id=client_name, protocol=mqtt_version)
self._client = mqtt.Client(
callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
client_id=client_name,
protocol=mqtt_version,
)
self._client.on_connect = self._on_connect

if use_tls:
Expand All @@ -53,8 +58,8 @@ def connect(self, host: str, port: int) -> None:
return

try:
self._client.connect(host, port, keepalive=60)
except (TimeoutError, ValueError, ConnectionRefusedError):
self._client.connect_async(host, port, keepalive=60)
except ValueError:
LOGGER.error(f"Failed to connect to MQTT broker at {host}:{port}")
return
self._client.loop_start()
Expand Down Expand Up @@ -149,12 +154,16 @@ def wrapped_publish(
retain=retain, abs_topic=abs_topic)

def _on_connect(
self, client: mqtt.Client, userdata: Any, flags: dict[str, int], rc: int,
self,
client: mqtt.Client,
userdata: Any,
connect_flags: mqtt.ConnectFlags,
reason_code: mqtt.ReasonCode,
properties: mqtt.Properties | None = None,
) -> None:
if rc != mqtt.CONNACK_ACCEPTED:
if reason_code.is_failure:
LOGGER.warning(
f"Failed to connect to MQTT broker. Return code: {mqtt.error_string(rc)}"
f"Failed to connect to MQTT broker. Return code: {reason_code.getName()}" # type: ignore[no-untyped-call] # noqa: E501
)
return

Expand All @@ -181,36 +190,17 @@ def get_mqtt_variables() -> MQTTVariables:
# url format: mqtt[s]://<username>:<password>@<host>:<port>/<topic_root>
mqtt_url = os.environ['SBOT_MQTT_URL']

scheme, rest = mqtt_url.split('://', maxsplit=1)
# username and password are optional
try:
user_pass, host_port_topic = rest.rsplit('@', maxsplit=1)
except ValueError:
username, password = None, None
host_port_topic = rest
else:
try:
username, password = user_pass.split(':', maxsplit=1)
except ValueError:
# username can be supplied without password
username = user_pass
password = None

host_port, topic_root = host_port_topic.split('/', maxsplit=1)
use_tls = (scheme == 'mqtts')
try:
host, port_str = host_port.split(':', maxsplit=1)
port = int(port_str)
except ValueError:
# use default port for scheme
host = host_port
port = 8883 if use_tls else 1883
url_parts = urlparse(mqtt_url, allow_fragments=False)
use_tls = (url_parts.scheme == 'mqtts')

if url_parts.hostname is None:
raise ValueError("MQTT URL is missing a hostname.")

return MQTTVariables(
host=host,
port=port,
topic_prefix=topic_root,
host=url_parts.hostname,
port=url_parts.port or (8883 if use_tls else 1883),
topic_prefix=url_parts.path.lstrip('/'),
use_tls=use_tls,
username=username,
password=password,
username=url_parts.username,
password=url_parts.password,
)

0 comments on commit d87bf96

Please sign in to comment.