Skip to content

Commit

Permalink
Enable MQTT clusters (#115)
Browse files Browse the repository at this point in the history
* Added support for multiple MQTT brokers to enable clustering

* Bump version to 4.5.0
PatrickBaus authored Jan 21, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 35b220c commit 38c4ddf
Showing 4 changed files with 139 additions and 66 deletions.
2 changes: 1 addition & 1 deletion _version.py
Original file line number Diff line number Diff line change
@@ -2,4 +2,4 @@
Kraken version information.
"""

__version__ = "4.4.0"
__version__ = "4.5.0"
13 changes: 5 additions & 8 deletions kraken_ethernet.py
Original file line number Diff line number Diff line change
@@ -110,14 +110,11 @@ def load_database_parameters():
@staticmethod
def load_mqtt_parameters() -> MQTTParams:
"""Loads MQTT broker parameters from env variables"""
result: MQTTParams = {
"hostname": config("MQTT_HOST"),
"port": config("MQTT_PORT", cast=int, default=1883),
"username": load_secret("MQTT_CLIENT_USER", default=None),
"password": load_secret("MQTT_CLIENT_PASSWORD", default=None),
}

return result
return MQTTParams(
hosts=config("MQTT_HOST"),
username=load_secret("MQTT_CLIENT_USER", default=None),
password=load_secret("MQTT_CLIENT_PASSWORD", default=None),
)

async def run(self): # pylint: disable=too-many-locals
"""
146 changes: 89 additions & 57 deletions managers.py
Original file line number Diff line number Diff line change
@@ -7,15 +7,17 @@
from __future__ import annotations

import asyncio
import itertools
import logging
import re
from contextlib import AsyncExitStack
from typing import Any, TypedDict, Unpack
from typing import Any, Unpack
from uuid import UUID

import aiomqtt
import simplejson as json
from aiostream import pipe, stream
from pydantic import BaseModel, field_validator

from async_event_bus import TopicNotRegisteredError, event_bus
from data_types import DataEvent
@@ -29,14 +31,68 @@
MQTT_DATA_TOPIC = "sensors/{driver}/{uid}/{sid}"


class MQTTParams(TypedDict):
"""Parameters used for the MQTT broker."""

hostname: str
port: int
# A regular expression to match a hostname with an optional port.
# It adheres to RFC 1035 (https://www.rfc-editor.org/rfc/rfc1035) and matches ports
# between 0-65535.
HOSTNAME_REGEX = (
r"^((?=.{1,255}$)[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?(?:\.[0-9A-Za-z](?:(?:["
r"0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*\.?)(?:\:([0-9]{1,4}|[1-5][0-9]{4}|6[0-4][0-9]{3}|65[0-4][0-9]{"
r"2}|655[0-2][0-9]|6553[0-5]))?$"
)


class MQTTParams(BaseModel):
"""
Parameters used to connect to the MQTT broker.
Parameters
----------
hosts: List of Tuple of str and int
A list of host:port tuples. The list contains the servers of a cluster. If no port is provided it defaults to
1883. If port number 0 is provided the default value of 1883 is used.
identifier: str or None
An MQTT client id used to uniquely identify a client to persist messages.
username: str or None
The username used for authentication. Set to None if no username is required
password: str or None
The password used for authentication. Set to None if no username is required
"""

hosts: list[tuple[str, int]]
identifier: str | None
username: str | None
password: str | None

@field_validator("hosts", mode="before")
@classmethod
def ensure_list_of_hosts(cls, value: str) -> list[tuple[str, int]]:
"""
Parse
Parameters
----------
value: str
Either a single hostname:port string or a comma separated list of hostname:port strings.
Returns
-------
list of tuple of str and int
A list of (hostname, port) tuples.
"""
hosts = value.split(",")
result = []
for host in hosts:
host = host.strip()
match = re.search(HOSTNAME_REGEX, host)
if match is None:
raise ValueError(f"'{value}' is not a valid hostname or list of hostnames.")
result.append(
(
match.group(1),
int(match.group(2)) if match.group(2) and not match.group(2) == "0" else 1883,
)
)
return result


class MqttManager:
"""This manager will take the sensor data from the event_bus backend and publish them onto the MQTT network"""
@@ -95,7 +151,9 @@ def _calculate_timeout(last_reconnect_attempt: float, reconnect_interval: float)
"""
return max(0.0, reconnect_interval - (asyncio.get_running_loop().time() - last_reconnect_attempt))

def _log_mqtt_error_code(self, worker_name: str, error_code: str | int, previous_error_code: str | int) -> None:
def _log_mqtt_error_code(
self, worker_name: str, host: tuple[str, int], error_code: str | int, previous_error_code: str | int
) -> None:
"""
Log the MQTT error codes as human-readable errors to the logger (error log level). If the code is unknown, log
it as an exception for debugging. Suppresses errors, if they are repeated.
@@ -114,50 +172,29 @@ def _log_mqtt_error_code(self, worker_name: str, error_code: str | int, previous
self.__logger.error(
"Worker (%s): Connection refused by MQTT broker (%s:%i). Retrying.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
*host,
)
elif error_code == 113:
self.__logger.error(
"Worker (%s): MQTT broker (%s:%i) is unreachable. Retrying.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
)
self.__logger.error("Worker (%s): MQTT broker (%s:%i) is unreachable. Retrying.", worker_name, *host)
elif error_code == 7:
self.__logger.error(
"Worker (%s): The connection to MQTT broker (%s:%i) was lost. Retrying.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
"Worker (%s): The connection to MQTT broker (%s:%i) was lost. Retrying.", worker_name, *host
)
elif error_code == -2:
self.__logger.error(
"Worker (%s): Failure in name resolution of MQTT broker (%s:%i). Retrying.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
"Worker (%s): Failure in name resolution of MQTT broker (%s:%i). Retrying.", worker_name, *host
)
elif error_code == -3:
self.__logger.error(
"Worker (%s): Temporary failure in name resolution of MQTT broker (%s:%i). Retrying.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
*host,
)
elif error_code == -5:
self.__logger.error(
"Worker (%s): Unknown host name of MQTT broker (%s:%i). Retrying.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
)
self.__logger.error("Worker (%s): Unknown host name of MQTT broker (%s:%i). Retrying.", worker_name, *host)
elif error_code == "timed out":
self.__logger.error(
"Worker (%s): The connection to MQTT broker (%s:%i) timed out. Retrying.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
"Worker (%s): The connection to MQTT broker (%s:%i) timed out. Retrying.", worker_name, *host
)
else:
self.__logger.exception("Worker (%s): MQTT connection error (code: %s). Retrying.", worker_name, error_code)
@@ -184,37 +221,30 @@ async def consumer( # pylint: disable=too-many-branches
error_code: str | int = 0 # 0 = success
event: tuple[str, dict[str, str | float | int]] | None = None
previous_reconnect_attempt = asyncio.get_running_loop().time() - reconnect_interval
while "not connected":
for host in itertools.cycle(self.__broker.hosts): # iterate over the list of hostnames until the end of time
# Wait for at least reconnect_interval before connecting again
timeout = self._calculate_timeout(previous_reconnect_attempt, reconnect_interval)
if round(timeout) > 0:
# Do not print '0 s' as this is confusing.
self.__logger.info(
"Worker (%s): Connecting to MQTT broker (%s:%i) in %.0f s due to rate limiting.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
*host,
timeout,
)
else:
self.__logger.info(
"Worker (%s): Connecting to MQTT broker (%s:%i).",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
)
self.__logger.info("Worker (%s): Connecting to MQTT broker (%s:%i).", worker_name, *host)
await asyncio.sleep(timeout)
previous_reconnect_attempt = asyncio.get_running_loop().time()
try:
async with aiomqtt.Client(
identifier=f"Labkraken-{self.__node_id}_worker-{worker_name}",
**self.__broker,
hostname=host[0],
port=host[1],
**self.__broker.model_dump(exclude={"hosts"}),
) as mqtt_client:
self.__logger.info(
"Worker (%s): Successfully connected to MQTT broker (%s:%i).",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
"Worker (%s): Successfully connected to MQTT broker (%s:%i).", worker_name, *host
)
while "loop not cancelled":
if event is None:
@@ -240,30 +270,32 @@ async def consumer( # pylint: disable=too-many-branches
input_queue.task_done()
error_code = 0 # 0 = success
except aiomqtt.MqttCodeError as exc:
self._log_mqtt_error_code(worker_name, error_code=exc.rc, previous_error_code=error_code)
self._log_mqtt_error_code(worker_name, host=host, error_code=exc.rc, previous_error_code=error_code)
error_code = exc.rc
except ConnectionRefusedError:
self._log_mqtt_error_code(
worker_name, error_code=111, previous_error_code=error_code
worker_name, host=host, error_code=111, previous_error_code=error_code
) # Connection refused is code 111
error_code = 111
except aiomqtt.MqttError as exc:
error = re.search(r"\[Errno ([+-]?\d+)]", str(exc))
if error is not None:
self._log_mqtt_error_code(
worker_name, error_code=int(error.group(1)), previous_error_code=error_code
worker_name,
host=host,
error_code=int(error.group(1)),
previous_error_code=error_code,
)
error_code = int(error.group(1))
else: # no match found
self._log_mqtt_error_code(worker_name, error_code=str(exc), previous_error_code=error_code)
self._log_mqtt_error_code(
worker_name, host=host, error_code=str(exc), previous_error_code=error_code
)
error_code = str(exc)
except Exception: # pylint: disable=broad-except
# Catch all exceptions, log them, then try to restart the worker.
self.__logger.exception(
"Worker (%s): Error while publishing data to MQTT broker (%s:%i). Reconnecting.",
worker_name,
self.__broker["hostname"],
self.__broker["port"],
"Worker (%s): Error while publishing data to MQTT broker (%s:%i). Reconnecting.", worker_name, *host
)

async def cancel_tasks(self, tasks: set[asyncio.Task]) -> None:
44 changes: 44 additions & 0 deletions tests/test_env_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Tests for the input environment variable parser.
"""

import pytest
from pydantic import ValidationError

from managers import MQTTParams


@pytest.mark.parametrize(
["hosts", "result"],
[
["example.com", [("example.com", 1883)]],
["example.com:1234", [("example.com", 1234)]],
["example1.com,example2.com", [("example1.com", 1883), ("example2.com", 1883)]],
["example1.com, example2.com", [("example1.com", 1883), ("example2.com", 1883)]],
["example1.com:1234,example2.com", [("example1.com", 1234), ("example2.com", 1883)]],
["example1.com:1234,example2.com:0", [("example1.com", 1234), ("example2.com", 1883)]],
],
)
def test_mqtt_hostnames_pass(hosts, result):
"""
Test parsing hostname(s) from env variables.
"""
env_params = MQTTParams(hosts=hosts, identifier="foo", username="bar", password="12345")

assert env_params.hosts == result


@pytest.mark.parametrize(
["hosts", "expectation"],
[
["example.com:123456", pytest.raises(ValidationError)],
["e^xample.com", pytest.raises(ValidationError)],
["example.com:-1", pytest.raises(ValidationError)],
],
)
def test_mqtt_hostnames_fail(hosts, expectation):
"""
Test parsing hostname(s) from env variables. Test invalid hostnames.
"""
with expectation:
MQTTParams(hosts=hosts, identifier="foo", username="bar", password="12345")

0 comments on commit 38c4ddf

Please sign in to comment.