diff --git a/_version.py b/_version.py index 7525884..65396de 100644 --- a/_version.py +++ b/_version.py @@ -2,4 +2,4 @@ Kraken version information. """ -__version__ = "4.4.0" +__version__ = "4.5.0" diff --git a/kraken_ethernet.py b/kraken_ethernet.py index ed07f56..ff68de8 100755 --- a/kraken_ethernet.py +++ b/kraken_ethernet.py @@ -110,14 +110,11 @@ def load_database_parameters(): @staticmethod def load_mqtt_parameters() -> MQTTParams: """Loads MQTT broker parameters from env variables""" - result: MQTTParams = { - "hostname": config("MQTT_HOST"), - "port": config("MQTT_PORT", cast=int, default=1883), - "username": load_secret("MQTT_CLIENT_USER", default=None), - "password": load_secret("MQTT_CLIENT_PASSWORD", default=None), - } - - return result + return MQTTParams( + hosts=config("MQTT_HOST"), + username=load_secret("MQTT_CLIENT_USER", default=None), + password=load_secret("MQTT_CLIENT_PASSWORD", default=None), + ) async def run(self): # pylint: disable=too-many-locals """ diff --git a/managers.py b/managers.py index aefdf61..8e5f8ad 100644 --- a/managers.py +++ b/managers.py @@ -7,15 +7,17 @@ from __future__ import annotations import asyncio +import itertools import logging import re from contextlib import AsyncExitStack -from typing import Any, TypedDict, Unpack +from typing import Any, Unpack from uuid import UUID import aiomqtt import simplejson as json from aiostream import pipe, stream +from pydantic import BaseModel, field_validator from async_event_bus import TopicNotRegisteredError, event_bus from data_types import DataEvent @@ -29,14 +31,68 @@ MQTT_DATA_TOPIC = "sensors/{driver}/{uid}/{sid}" -class MQTTParams(TypedDict): - """Parameters used for the MQTT broker.""" - - hostname: str - port: int +# A regular expression to match a hostname with an optional port. +# It adheres to RFC 1035 (https://www.rfc-editor.org/rfc/rfc1035) and matches ports +# between 0-65535. +HOSTNAME_REGEX = ( + r"^((?=.{1,255}$)[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?(?:\.[0-9A-Za-z](?:(?:[" + r"0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*\.?)(?:\:([0-9]{1,4}|[1-5][0-9]{4}|6[0-4][0-9]{3}|65[0-4][0-9]{" + r"2}|655[0-2][0-9]|6553[0-5]))?$" +) + + +class MQTTParams(BaseModel): + """ + Parameters used to connect to the MQTT broker. + + Parameters + ---------- + hosts: List of Tuple of str and int + A list of host:port tuples. The list contains the servers of a cluster. If no port is provided it defaults to + 1883. If port number 0 is provided the default value of 1883 is used. + identifier: str or None + An MQTT client id used to uniquely identify a client to persist messages. + username: str or None + The username used for authentication. Set to None if no username is required + password: str or None + The password used for authentication. Set to None if no username is required + """ + + hosts: list[tuple[str, int]] + identifier: str | None username: str | None password: str | None + @field_validator("hosts", mode="before") + @classmethod + def ensure_list_of_hosts(cls, value: str) -> list[tuple[str, int]]: + """ + Parse + Parameters + ---------- + value: str + Either a single hostname:port string or a comma separated list of hostname:port strings. + + Returns + ------- + list of tuple of str and int + A list of (hostname, port) tuples. + """ + hosts = value.split(",") + result = [] + for host in hosts: + host = host.strip() + match = re.search(HOSTNAME_REGEX, host) + if match is None: + raise ValueError(f"'{value}' is not a valid hostname or list of hostnames.") + result.append( + ( + match.group(1), + int(match.group(2)) if match.group(2) and not match.group(2) == "0" else 1883, + ) + ) + return result + class MqttManager: """This manager will take the sensor data from the event_bus backend and publish them onto the MQTT network""" @@ -95,7 +151,9 @@ def _calculate_timeout(last_reconnect_attempt: float, reconnect_interval: float) """ return max(0.0, reconnect_interval - (asyncio.get_running_loop().time() - last_reconnect_attempt)) - def _log_mqtt_error_code(self, worker_name: str, error_code: str | int, previous_error_code: str | int) -> None: + def _log_mqtt_error_code( + self, worker_name: str, host: tuple[str, int], error_code: str | int, previous_error_code: str | int + ) -> None: """ Log the MQTT error codes as human-readable errors to the logger (error log level). If the code is unknown, log it as an exception for debugging. Suppresses errors, if they are repeated. @@ -114,50 +172,29 @@ def _log_mqtt_error_code(self, worker_name: str, error_code: str | int, previous self.__logger.error( "Worker (%s): Connection refused by MQTT broker (%s:%i). Retrying.", worker_name, - self.__broker["hostname"], - self.__broker["port"], + *host, ) elif error_code == 113: - self.__logger.error( - "Worker (%s): MQTT broker (%s:%i) is unreachable. Retrying.", - worker_name, - self.__broker["hostname"], - self.__broker["port"], - ) + self.__logger.error("Worker (%s): MQTT broker (%s:%i) is unreachable. Retrying.", worker_name, *host) elif error_code == 7: self.__logger.error( - "Worker (%s): The connection to MQTT broker (%s:%i) was lost. Retrying.", - worker_name, - self.__broker["hostname"], - self.__broker["port"], + "Worker (%s): The connection to MQTT broker (%s:%i) was lost. Retrying.", worker_name, *host ) elif error_code == -2: self.__logger.error( - "Worker (%s): Failure in name resolution of MQTT broker (%s:%i). Retrying.", - worker_name, - self.__broker["hostname"], - self.__broker["port"], + "Worker (%s): Failure in name resolution of MQTT broker (%s:%i). Retrying.", worker_name, *host ) elif error_code == -3: self.__logger.error( "Worker (%s): Temporary failure in name resolution of MQTT broker (%s:%i). Retrying.", worker_name, - self.__broker["hostname"], - self.__broker["port"], + *host, ) elif error_code == -5: - self.__logger.error( - "Worker (%s): Unknown host name of MQTT broker (%s:%i). Retrying.", - worker_name, - self.__broker["hostname"], - self.__broker["port"], - ) + self.__logger.error("Worker (%s): Unknown host name of MQTT broker (%s:%i). Retrying.", worker_name, *host) elif error_code == "timed out": self.__logger.error( - "Worker (%s): The connection to MQTT broker (%s:%i) timed out. Retrying.", - worker_name, - self.__broker["hostname"], - self.__broker["port"], + "Worker (%s): The connection to MQTT broker (%s:%i) timed out. Retrying.", worker_name, *host ) else: self.__logger.exception("Worker (%s): MQTT connection error (code: %s). Retrying.", worker_name, error_code) @@ -184,7 +221,7 @@ async def consumer( # pylint: disable=too-many-branches error_code: str | int = 0 # 0 = success event: tuple[str, dict[str, str | float | int]] | None = None previous_reconnect_attempt = asyncio.get_running_loop().time() - reconnect_interval - while "not connected": + for host in itertools.cycle(self.__broker.hosts): # iterate over the list of hostnames until the end of time # Wait for at least reconnect_interval before connecting again timeout = self._calculate_timeout(previous_reconnect_attempt, reconnect_interval) if round(timeout) > 0: @@ -192,29 +229,22 @@ async def consumer( # pylint: disable=too-many-branches self.__logger.info( "Worker (%s): Connecting to MQTT broker (%s:%i) in %.0f s due to rate limiting.", worker_name, - self.__broker["hostname"], - self.__broker["port"], + *host, timeout, ) else: - self.__logger.info( - "Worker (%s): Connecting to MQTT broker (%s:%i).", - worker_name, - self.__broker["hostname"], - self.__broker["port"], - ) + self.__logger.info("Worker (%s): Connecting to MQTT broker (%s:%i).", worker_name, *host) await asyncio.sleep(timeout) previous_reconnect_attempt = asyncio.get_running_loop().time() try: async with aiomqtt.Client( identifier=f"Labkraken-{self.__node_id}_worker-{worker_name}", - **self.__broker, + hostname=host[0], + port=host[1], + **self.__broker.model_dump(exclude={"hosts"}), ) as mqtt_client: self.__logger.info( - "Worker (%s): Successfully connected to MQTT broker (%s:%i).", - worker_name, - self.__broker["hostname"], - self.__broker["port"], + "Worker (%s): Successfully connected to MQTT broker (%s:%i).", worker_name, *host ) while "loop not cancelled": if event is None: @@ -240,30 +270,32 @@ async def consumer( # pylint: disable=too-many-branches input_queue.task_done() error_code = 0 # 0 = success except aiomqtt.MqttCodeError as exc: - self._log_mqtt_error_code(worker_name, error_code=exc.rc, previous_error_code=error_code) + self._log_mqtt_error_code(worker_name, host=host, error_code=exc.rc, previous_error_code=error_code) error_code = exc.rc except ConnectionRefusedError: self._log_mqtt_error_code( - worker_name, error_code=111, previous_error_code=error_code + worker_name, host=host, error_code=111, previous_error_code=error_code ) # Connection refused is code 111 error_code = 111 except aiomqtt.MqttError as exc: error = re.search(r"\[Errno ([+-]?\d+)]", str(exc)) if error is not None: self._log_mqtt_error_code( - worker_name, error_code=int(error.group(1)), previous_error_code=error_code + worker_name, + host=host, + error_code=int(error.group(1)), + previous_error_code=error_code, ) error_code = int(error.group(1)) else: # no match found - self._log_mqtt_error_code(worker_name, error_code=str(exc), previous_error_code=error_code) + self._log_mqtt_error_code( + worker_name, host=host, error_code=str(exc), previous_error_code=error_code + ) error_code = str(exc) except Exception: # pylint: disable=broad-except # Catch all exceptions, log them, then try to restart the worker. self.__logger.exception( - "Worker (%s): Error while publishing data to MQTT broker (%s:%i). Reconnecting.", - worker_name, - self.__broker["hostname"], - self.__broker["port"], + "Worker (%s): Error while publishing data to MQTT broker (%s:%i). Reconnecting.", worker_name, *host ) async def cancel_tasks(self, tasks: set[asyncio.Task]) -> None: diff --git a/tests/test_env_parser.py b/tests/test_env_parser.py new file mode 100644 index 0000000..d37caee --- /dev/null +++ b/tests/test_env_parser.py @@ -0,0 +1,44 @@ +""" +Tests for the input environment variable parser. +""" + +import pytest +from pydantic import ValidationError + +from managers import MQTTParams + + +@pytest.mark.parametrize( + ["hosts", "result"], + [ + ["example.com", [("example.com", 1883)]], + ["example.com:1234", [("example.com", 1234)]], + ["example1.com,example2.com", [("example1.com", 1883), ("example2.com", 1883)]], + ["example1.com, example2.com", [("example1.com", 1883), ("example2.com", 1883)]], + ["example1.com:1234,example2.com", [("example1.com", 1234), ("example2.com", 1883)]], + ["example1.com:1234,example2.com:0", [("example1.com", 1234), ("example2.com", 1883)]], + ], +) +def test_mqtt_hostnames_pass(hosts, result): + """ + Test parsing hostname(s) from env variables. + """ + env_params = MQTTParams(hosts=hosts, identifier="foo", username="bar", password="12345") + + assert env_params.hosts == result + + +@pytest.mark.parametrize( + ["hosts", "expectation"], + [ + ["example.com:123456", pytest.raises(ValidationError)], + ["e^xample.com", pytest.raises(ValidationError)], + ["example.com:-1", pytest.raises(ValidationError)], + ], +) +def test_mqtt_hostnames_fail(hosts, expectation): + """ + Test parsing hostname(s) from env variables. Test invalid hostnames. + """ + with expectation: + MQTTParams(hosts=hosts, identifier="foo", username="bar", password="12345")