From 804107d6f692cac86f4b55fb05a56fef6c82956f Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sun, 17 Dec 2023 21:27:15 +0000 Subject: [PATCH 1/8] Refactor tests to allow running trough cmdline unittest --- README.md | 9 ++++++++- requirements.txt | 2 +- wgkex/common/BUILD | 3 ++- wgkex/common/utils.py | 2 ++ wgkex/common/utils_test.py | 2 +- wgkex/config/BUILD | 2 +- wgkex/config/config_test.py | 3 ++- wgkex/worker/BUILD | 13 +++++++------ wgkex/worker/app_test.py | 24 ++++++++++++++++-------- wgkex/worker/mqtt.py | 2 +- wgkex/worker/mqtt_test.py | 8 ++++---- wgkex/worker/netlink_test.py | 3 ++- 12 files changed, 47 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 18bf361..2c32b25 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,14 @@ Worker: python3 -c 'from wgkex.worker.app import main; main()' ``` -## Client usage + +## Development + +### Unit tests + +The test can be run using `bazel test ... --test_output=all` or `python3 -m unittest discover -p '*_test.py'`. + +### Client The client can be used via CLI: diff --git a/requirements.txt b/requirements.txt index 97a41ba..1821412 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ waitress~=2.1.2 ipaddress~=1.0.23 mock~=5.1.0 coverage -paho-mqtt~=1.6.1 \ No newline at end of file +paho-mqtt~=1.6.1 diff --git a/wgkex/common/BUILD b/wgkex/common/BUILD index 4a12559..93b284b 100644 --- a/wgkex/common/BUILD +++ b/wgkex/common/BUILD @@ -15,7 +15,8 @@ py_test( name = "utils_test", srcs = ["utils_test.py"], deps = [ - ":utils", + "//wgkex/common:utils", + "//wgkex/config:config", requirement("mock"), ], ) diff --git a/wgkex/common/utils.py b/wgkex/common/utils.py index fecebef..276c2de 100644 --- a/wgkex/common/utils.py +++ b/wgkex/common/utils.py @@ -2,6 +2,8 @@ import ipaddress import re +from wgkex.config import config + def mac2eui64(mac: str, prefix=None) -> str: """Converts a MAC address to an EUI64 identifier. diff --git a/wgkex/common/utils_test.py b/wgkex/common/utils_test.py index e14b174..a0aa187 100644 --- a/wgkex/common/utils_test.py +++ b/wgkex/common/utils_test.py @@ -1,5 +1,5 @@ import unittest -import utils +from wgkex.common import utils class UtilsTest(unittest.TestCase): diff --git a/wgkex/config/BUILD b/wgkex/config/BUILD index 1ca5fb3..8167c58 100644 --- a/wgkex/config/BUILD +++ b/wgkex/config/BUILD @@ -16,7 +16,7 @@ py_test( name="config_test", srcs=["config_test.py"], deps=[ - ":config", + "//wgkex/config:config", requirement("mock"), ], ) diff --git a/wgkex/config/config_test.py b/wgkex/config/config_test.py index 3c33148..d8d6a15 100644 --- a/wgkex/config/config_test.py +++ b/wgkex/config/config_test.py @@ -1,9 +1,10 @@ """Tests for configuration handling class.""" import unittest import mock -import config import yaml +from wgkex.config import config + _VALID_CFG = ( "domain_prefixes:\n- ffmuc_\n- ffdon_\n- ffwert_\nlog_level: DEBUG\ndomains:\n- a\n- b\nmqtt:\n broker_port: 1883" "\n broker_url: mqtt://broker\n keepalive: 5\n password: pass\n tls: true\n username: user\n" diff --git a/wgkex/worker/BUILD b/wgkex/worker/BUILD index 80a82eb..7f1c2c3 100644 --- a/wgkex/worker/BUILD +++ b/wgkex/worker/BUILD @@ -21,8 +21,9 @@ py_test( name = "netlink_test", srcs = ["netlink_test.py"], deps = [ - ":netlink", + "//wgkex/worker:netlink", requirement("mock"), + requirement("pyroute2"), ], ) @@ -46,8 +47,8 @@ py_test( name = "mqtt_test", srcs = ["mqtt_test.py"], deps = [ - ":mqtt", - ":msg_queue", + "//wgkex/worker:mqtt", + "//wgkex/worker:msg_queue", requirement("mock"), ], ) @@ -67,8 +68,8 @@ py_test( name = "app_test", srcs = ["app_test.py"], deps = [ - ":app", - ":msg_queue", + "//wgkex/worker:app", + "//wgkex/worker:msg_queue", requirement("mock"), ], ) @@ -80,4 +81,4 @@ py_library( deps = [ "//wgkex/common:logger", ], -) \ No newline at end of file +) diff --git a/wgkex/worker/app_test.py b/wgkex/worker/app_test.py index 111590b..717fcfe 100644 --- a/wgkex/worker/app_test.py +++ b/wgkex/worker/app_test.py @@ -1,7 +1,9 @@ """Unit tests for app.py""" import unittest import mock -import app + +import wgkex.config.config +from wgkex.worker import app class AppTest(unittest.TestCase): @@ -48,43 +50,49 @@ def test_unique_domains_not_list(self): with self.assertRaises(TypeError): app.check_all_domains_unique(test_domains, test_prefixes) - @mock.patch.object(app.config, "load_config") + @mock.patch.object(wgkex.config.config, "fetch_from_config") + @mock.patch.object(wgkex.config.config, "load_config") @mock.patch.object(app.mqtt, "connect", autospec=True) - def test_main_success(self, connect_mock, config_mock): + def test_main_success(self, connect_mock, config_mock, config_fetch_mock): """Ensure we can execute main.""" connect_mock.return_value = None test_prefixes = ["TEST_PREFIX_", "TEST_PREFIX2_"] config_mock.return_value = dict( domains=[f"{test_prefixes[1]}domain.one"], domain_prefixes=test_prefixes ) - with mock.patch("app.flush_workers", return_value=None): + config_fetch_mock.side_effect = config_mock().get + with mock.patch.object(app, "flush_workers", return_value=None): app.main() - connect_mock.assert_called_with() + connect_mock.assert_called() + @mock.patch.object(wgkex.config.config, "fetch_from_config") @mock.patch.object(app.config, "load_config") @mock.patch.object(app.mqtt, "connect", autospec=True) - def test_main_fails_no_domain(self, connect_mock, config_mock): + def test_main_fails_no_domain(self, connect_mock, config_mock, config_fetch_mock): """Ensure we fail when domains are not configured.""" config_mock.return_value = dict(domains=None) + config_fetch_mock.side_effect = config_mock().get connect_mock.return_value = None with self.assertRaises(app.DomainsNotInConfig): app.main() + @mock.patch.object(wgkex.config.config, "fetch_from_config") @mock.patch.object(app.config, "load_config") @mock.patch.object(app.mqtt, "connect", autospec=True) - def test_main_fails_bad_domain(self, connect_mock, config_mock): + def test_main_fails_bad_domain(self, connect_mock, config_mock, config_fetch_mock): """Ensure we fail when domains are badly formatted.""" test_prefixes = ["TEST_PREFIX_", "TEST_PREFIX2_"] config_mock.return_value = dict( domains=[f"cant_split_domain"], domain_prefixes=test_prefixes ) + config_fetch_mock.side_effect = config_mock().get connect_mock.return_value = None with mock.patch("app.flush_workers", return_value=None): app.main() connect_mock.assert_called_with() @mock.patch("time.sleep", side_effect=InterruptedError) - @mock.patch("app.wg_flush_stale_peers") + @mock.patch.object(app, "wg_flush_stale_peers") def test_flush_workers(self, flush_mock, sleep_mock): """Ensure we fail when domains are badly formatted.""" flush_mock.return_value = "" diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index 1c1cf31..2216d69 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -92,7 +92,7 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> break if not domain: raise ValueError( - "Could not find a match for %s on %s", repr(domain_prefixes), message.topic + f"Could not find a match for {domain_prefixes} on {message.topic}" ) # this will not work, if we have non-unique prefix stripped domains domain = domain.group(1) diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index 8e2fcbf..8aece41 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -1,8 +1,8 @@ """Unit tests for mqtt.py""" import unittest import mock -import mqtt -import msg_queue + +from wgkex.worker import mqtt class MQTTTest(unittest.TestCase): @@ -48,7 +48,7 @@ def test_on_message_success(self, config_mock, link_mock): config_mock.return_value = {"domain_prefix": "_ffmuc_"} link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") - mqtt_msg.topic = "/_ffmuc_domain1/" + mqtt_msg.topic = "wireguard/_ffmuc_domain1/gateway" mqtt_msg.payload = b"PUB_KEY" mqtt.on_message(None, None, mqtt_msg) link_mock.assert_has_calls( @@ -80,7 +80,7 @@ def test_on_message_fails_no_domain(self, config_mock, link_mock): } link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") - mqtt_msg.topic = "bad_domain_match" + mqtt_msg.topic = "wireguard/bad_domain_match" with self.assertRaises(ValueError): mqtt.on_message(None, None, mqtt_msg) """ diff --git a/wgkex/worker/netlink_test.py b/wgkex/worker/netlink_test.py index aeb4ff3..c209731 100644 --- a/wgkex/worker/netlink_test.py +++ b/wgkex/worker/netlink_test.py @@ -13,7 +13,8 @@ sys.modules["pyroute2.IPRoute"] = mock.MagicMock() from pyroute2 import WireGuard from pyroute2 import IPRoute -import netlink + +from wgkex.worker import netlink _WG_CLIENT_ADD = netlink.WireGuardClient( public_key="public_key", domain="add", remove=False From 2f1a9586fb1e9fae404d0a3bb2d01a79b0bdf294 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sun, 17 Dec 2023 21:27:15 +0000 Subject: [PATCH 2/8] Use Config class over raw dict everywhere --- wgkex/broker/app.py | 35 +++---------- wgkex/common/BUILD | 2 +- wgkex/common/utils.py | 14 ++++++ wgkex/config/__init__.py | 4 +- wgkex/config/config.py | 98 +++++++++++++++++++++++-------------- wgkex/config/config_test.py | 18 ++++--- wgkex/worker/app.py | 30 ++++++++---- wgkex/worker/app_test.py | 40 +++++++-------- wgkex/worker/mqtt.py | 47 +++++------------- wgkex/worker/mqtt_test.py | 74 +++++++++++++++------------- 10 files changed, 188 insertions(+), 174 deletions(-) diff --git a/wgkex/broker/app.py b/wgkex/broker/app.py index f01ec3f..f5d23d7 100644 --- a/wgkex/broker/app.py +++ b/wgkex/broker/app.py @@ -2,7 +2,6 @@ """wgkex broker""" import re import dataclasses -import logging from typing import Tuple, Any from flask import Flask @@ -17,6 +16,7 @@ from waitress import serve from wgkex.config import config from wgkex.common import logger +from wgkex.common.utils import is_valid_domain WG_PUBKEY_PATTERN = re.compile(r"^[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=$") @@ -43,7 +43,9 @@ def from_dict(cls, msg: dict) -> "KeyExchange": A KeyExchange object. """ public_key = is_valid_wg_pubkey(msg.get("public_key")) - domain = is_valid_domain(msg.get("domain")) + domain = str(msg.get("domain")) + if not is_valid_domain(domain): + raise ValueError(f"Domain {domain} not in configured domains.") return cls(public_key=public_key, domain=domain) @@ -54,8 +56,7 @@ def _fetch_app_config() -> Flask_app: A created Flask app. """ app = Flask(__name__) - # TODO(ruairi): Refactor load_config to return Dataclass. - mqtt_cfg = config.Config.from_dict(config.load_config()).mqtt + mqtt_cfg = config.get_config().mqtt app.config["MQTT_BROKER_URL"] = mqtt_cfg.broker_url app.config["MQTT_BROKER_PORT"] = mqtt_cfg.broker_port app.config["MQTT_USERNAME"] = mqtt_cfg.username @@ -140,33 +141,13 @@ def is_valid_wg_pubkey(pubkey: str) -> str: return pubkey -def is_valid_domain(domain: str) -> str: - """Verifies if the domain is configured. - - Arguments: - domain: The domain to verify. - - Raises: - ValueError: If the domain is not configured. - - Returns: - The domain. - """ - # TODO(ruairi): Refactor to return bool. - if domain not in config.fetch_from_config("domains"): - raise ValueError( - f'Domains {domain} not in configured domains({config.fetch_from_config("domains")}) a valid domain' - ) - return domain - - if __name__ == "__main__": listen_host = None listen_port = None - listen_config = config.fetch_from_config("broker_listen") + listen_config = config.get_config().broker_listen if listen_config is not None: - listen_host = listen_config.get("host") - listen_port = listen_config.get("port") + listen_host = listen_config.host + listen_port = listen_config.port serve(app, host=listen_host, port=listen_port) diff --git a/wgkex/common/BUILD b/wgkex/common/BUILD index 93b284b..7a79f93 100644 --- a/wgkex/common/BUILD +++ b/wgkex/common/BUILD @@ -25,4 +25,4 @@ py_library( name = "logger", srcs = ["logger.py"], visibility = ["//visibility:public"] -) \ No newline at end of file +) diff --git a/wgkex/common/utils.py b/wgkex/common/utils.py index 276c2de..45c7b7b 100644 --- a/wgkex/common/utils.py +++ b/wgkex/common/utils.py @@ -37,3 +37,17 @@ def mac2eui64(mac: str, prefix=None) -> str: net = ipaddress.ip_network(prefix, strict=False) euil = int(f"0x{eui64:16}", 16) return f"{net[euil]}/{net.prefixlen}" + + +def is_valid_domain(domain: str) -> bool: + """Verifies if the domain is configured. + + Arguments: + domain: The domain to verify. + + Returns: + True if the domain is valid, False otherwise. + """ + return domain in config.get_config().domains and domain.startswith( + config.get_config().domain_prefix + ) diff --git a/wgkex/config/__init__.py b/wgkex/config/__init__.py index 1b48be8..9c9cace 100644 --- a/wgkex/config/__init__.py +++ b/wgkex/config/__init__.py @@ -1,3 +1,3 @@ -from wgkex.config.config import load_config +from wgkex.config.config import get_config -__all__ = ["load_config"] +__all__ = ["get_config"] diff --git a/wgkex/config/config.py b/wgkex/config/config.py index 0659d69..29ba0a5 100644 --- a/wgkex/config/config.py +++ b/wgkex/config/config.py @@ -1,11 +1,11 @@ """Configuration handling class.""" +import dataclasses import logging import os import sys +from typing import Dict, Any, List, Optional + import yaml -from functools import lru_cache -from typing import Dict, Union, Any, List, Optional -import dataclasses class Error(Exception): @@ -20,9 +20,29 @@ class ConfigFileNotFoundError(Error): WG_CONFIG_DEFAULT_LOCATION = "/etc/wgkex.yaml" +@dataclasses.dataclass +class BrokerListen: + """A representation of the 'broker_listen' key in Configuration file. + + Attributes: + host: The listen address the broker should listen to for the HTTP API. + port: The port the broker should listen to for the HTTP API. + """ + + host: Optional[str] + port: Optional[int] + + @classmethod + def from_dict(cls, broker_listen: Dict[str, Any]) -> "BrokerListen": + return cls( + host=broker_listen.get("host"), + port=broker_listen.get("port"), + ) + + @dataclasses.dataclass class MQTT: - """A representation of MQTT key in Configuration file. + """A representation of the 'mqtt' key in Configuration file. Attributes: broker_url: The broker URL for MQTT to connect to. @@ -54,11 +74,9 @@ def from_dict(cls, mqtt_cfg: Dict[str, str]) -> "MQTT": broker_url=mqtt_cfg["broker_url"], username=mqtt_cfg["username"], password=mqtt_cfg["password"], - tls=mqtt_cfg["tls"] if mqtt_cfg["tls"] else False, - broker_port=int(mqtt_cfg["broker_port"]) - if mqtt_cfg["broker_port"] - else None, - keepalive=int(mqtt_cfg["keepalive"]) if mqtt_cfg["keepalive"] else None, + tls=bool(mqtt_cfg.get("tls", cls.tls)), + broker_port=int(mqtt_cfg.get("broker_port", cls.broker_port)), + keepalive=int(mqtt_cfg.get("keepalive", cls.keepalive)), ) @@ -68,59 +86,65 @@ class Config: Attributes: domains: The list of domains to listen for. + domain_prefixes: The list of prefixes to pre-pend to a given domain. mqtt: The MQTT configuration. - domain_prefixes: The list of prefixes to pre-pend to a given domain.""" + """ + raw: Dict[str, Any] domains: List[str] - mqtt: MQTT domain_prefixes: List[str] + broker_listen: BrokerListen + mqtt: MQTT @classmethod - def from_dict(cls, cfg: Dict[str, str]) -> "Config": + def from_dict(cls, cfg: Dict[str, Any]) -> "Config": """Creates a Config object from a configuration file. Arguments: cfg: The configuration file as a dict. Returns: A Config object. """ + broker_listen = BrokerListen.from_dict(cfg.get("broker_listen", {})) mqtt_cfg = MQTT.from_dict(cfg["mqtt"]) return cls( + raw=cfg, domains=cfg["domains"], - mqtt=mqtt_cfg, domain_prefixes=cfg["domain_prefixes"], + broker_listen=broker_listen, + mqtt=mqtt_cfg, ) + def get(self, key: str) -> Any: + """Get the value of key from the raw dict representation of the config file""" + return self.raw.get(key) -@lru_cache(maxsize=10) -def fetch_from_config(key: str) -> Optional[Union[Dict[str, Any], List[str]]]: - """Fetches a specific key from configuration. - Arguments: - key: The named key to fetch. - Returns: - The config value associated with the key - """ - return load_config().get(key) +_parsed_config: Optional[Config] = None -def load_config() -> Dict[str, str]: - """Fetches and validates configuration file from disk. +def get_config() -> Config: + """Returns a parsed Config object. + Raises: + ConfigFileNotFoundError: If we could not find the configuration file on disk. Returns: - Linted configuration file. + The Config representation of the config file """ - cfg_contents = fetch_config_from_disk() - try: - config = yaml.safe_load(cfg_contents) - except yaml.YAMLError as e: - print("Failed to load YAML file: %s", e) - sys.exit(1) - try: - _ = Config.from_dict(config) - return config - except (KeyError, TypeError) as e: - print("Failed to lint file: %s", e) - sys.exit(2) + global _parsed_config + if _parsed_config is None: + cfg_contents = fetch_config_from_disk() + try: + config = yaml.safe_load(cfg_contents) + except yaml.YAMLError as e: + print("Failed to load YAML file: %s" % e) + sys.exit(1) + try: + config = Config.from_dict(config) + except (KeyError, TypeError, AttributeError) as e: + print("Failed to lint file: %s" % e) + sys.exit(2) + _parsed_config = config + return _parsed_config def fetch_config_from_disk() -> str: diff --git a/wgkex/config/config_test.py b/wgkex/config/config_test.py index d8d6a15..6e30eb3 100644 --- a/wgkex/config/config_test.py +++ b/wgkex/config/config_test.py @@ -17,18 +17,22 @@ class TestConfig(unittest.TestCase): + def tearDown(self) -> None: + config._parsed_config = None + return super().tearDown() + def test_load_config_success(self): """Test loads and lint config successfully.""" mock_open = mock.mock_open(read_data=_VALID_CFG) with mock.patch("builtins.open", mock_open): - self.assertDictEqual(yaml.safe_load(_VALID_CFG), config.load_config()) + self.assertDictEqual(yaml.safe_load(_VALID_CFG), config.get_config().raw) @mock.patch.object(config.sys, "exit", autospec=True) def test_load_config_fails_good_yaml_bad_format(self, exit_mock): """Test loads yaml successfully and fails lint.""" mock_open = mock.mock_open(read_data=_INVALID_LINT) with mock.patch("builtins.open", mock_open): - config.load_config() + config.get_config() exit_mock.assert_called_with(2) @mock.patch.object(config.sys, "exit", autospec=True) @@ -36,7 +40,7 @@ def test_load_config_fails_bad_yaml(self, exit_mock): """Test loads bad YAML.""" mock_open = mock.mock_open(read_data=_INVALID_CFG) with mock.patch("builtins.open", mock_open): - config.load_config() + config.get_config() exit_mock.assert_called_with(2) def test_fetch_config_from_disk_success(self): @@ -53,17 +57,17 @@ def test_fetch_config_from_disk_fails_file_not_found(self): with self.assertRaises(config.ConfigFileNotFoundError): config.fetch_config_from_disk() - def test_fetch_from_config_success(self): + def test_raw_get_success(self): """Test fetch key from configuration.""" mock_open = mock.mock_open(read_data=_VALID_CFG) with mock.patch("builtins.open", mock_open): - self.assertListEqual(["a", "b"], config.fetch_from_config("domains")) + self.assertListEqual(["a", "b"], config.get_config().raw.get("domains")) - def test_fetch_from_config_no_key_in_config(self): + def test_raw_get_no_key_in_config(self): """Test fetch non-existent key from configuration.""" mock_open = mock.mock_open(read_data=_VALID_CFG) with mock.patch("builtins.open", mock_open): - self.assertIsNone(config.fetch_from_config("key_does_not_exist")) + self.assertIsNone(config.get_config().raw.get("key_does_not_exist")) if __name__ == "__main__": diff --git a/wgkex/worker/app.py b/wgkex/worker/app.py index 70aa8fc..7fd8b7f 100644 --- a/wgkex/worker/app.py +++ b/wgkex/worker/app.py @@ -1,13 +1,15 @@ """Initialises the MQTT worker.""" -import wgkex.config.config as config +import threading +import time +from typing import Text + +from wgkex.common import logger +from wgkex.common.utils import is_valid_domain +from wgkex.config import config from wgkex.worker import mqtt from wgkex.worker.msg_queue import watch_queue from wgkex.worker.netlink import wg_flush_stale_peers -import time -import threading -from wgkex.common import logger -from typing import List, Text _CLEANUP_TIME = 3600 @@ -28,6 +30,10 @@ class DomainsAreNotUnique(Error): """If non-unique domains exist in configuration file.""" +class InvalidDomain(Error): + """If the domains is invalid and is not listed in the configuration file.""" + + def flush_workers(domain: Text) -> None: """Calls peer flush every _CLEANUP_TIME interval.""" while True: @@ -36,14 +42,15 @@ def flush_workers(domain: Text) -> None: logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain)) -def clean_up_worker(domains: List[Text]) -> None: +def clean_up_worker() -> None: """Wraps flush_workers in a thread for all given domains. Arguments: domains: list of domains. """ + domains = config.get_config().domains + prefixes = config.get_config().domain_prefixes logger.debug("Cleaning up the following domains: %s", domains) - prefixes = config.load_config().get("domain_prefixes") cleanup_counter = 0 # ToDo: do we need a check if every domain got gleaned? for prefix in prefixes: @@ -104,13 +111,16 @@ def main(): DomainsNotInConfig: If no domains were found in configuration file. DomainsAreNotUnique: If there were non-unique domains after stripping prefix """ - domains = config.load_config().get("domains") - prefixes = config.load_config().get("domain_prefixes") + domains = config.get_config().domains + prefixes = config.get_config().domain_prefixes if not domains: raise DomainsNotInConfig("Could not locate domains in configuration.") if not check_all_domains_unique(domains, prefixes): raise DomainsAreNotUnique("There are non-unique domains! Check config.") - clean_up_worker(domains) + for domain in domains: + if not is_valid_domain(domain): + raise InvalidDomain(f"Domain {domain} has invalid prefix.") + clean_up_worker() watch_queue() mqtt.connect() diff --git a/wgkex/worker/app_test.py b/wgkex/worker/app_test.py index 717fcfe..cfb2f37 100644 --- a/wgkex/worker/app_test.py +++ b/wgkex/worker/app_test.py @@ -6,6 +6,16 @@ from wgkex.worker import app +def _get_config_mock(domains=None): + test_prefixes = ["_TEST_PREFIX_", "_TEST_PREFIX2_"] + config_mock = mock.MagicMock() + config_mock.domains = ( + domains if domains is not None else [f"{test_prefixes[1]}domain.one"] + ) + config_mock.domain_prefixes = test_prefixes + return config_mock + + class AppTest(unittest.TestCase): """unittest.TestCase class""" @@ -50,46 +60,34 @@ def test_unique_domains_not_list(self): with self.assertRaises(TypeError): app.check_all_domains_unique(test_domains, test_prefixes) - @mock.patch.object(wgkex.config.config, "fetch_from_config") - @mock.patch.object(wgkex.config.config, "load_config") + @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) - def test_main_success(self, connect_mock, config_mock, config_fetch_mock): + def test_main_success(self, connect_mock, config_mock): """Ensure we can execute main.""" connect_mock.return_value = None - test_prefixes = ["TEST_PREFIX_", "TEST_PREFIX2_"] - config_mock.return_value = dict( - domains=[f"{test_prefixes[1]}domain.one"], domain_prefixes=test_prefixes - ) - config_fetch_mock.side_effect = config_mock().get + config_mock.return_value = _get_config_mock() with mock.patch.object(app, "flush_workers", return_value=None): app.main() connect_mock.assert_called() - @mock.patch.object(wgkex.config.config, "fetch_from_config") - @mock.patch.object(app.config, "load_config") + @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) def test_main_fails_no_domain(self, connect_mock, config_mock, config_fetch_mock): """Ensure we fail when domains are not configured.""" - config_mock.return_value = dict(domains=None) - config_fetch_mock.side_effect = config_mock().get + config_mock.return_value = _get_config_mock(domains=[]) connect_mock.return_value = None with self.assertRaises(app.DomainsNotInConfig): app.main() - @mock.patch.object(wgkex.config.config, "fetch_from_config") - @mock.patch.object(app.config, "load_config") + @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) def test_main_fails_bad_domain(self, connect_mock, config_mock, config_fetch_mock): """Ensure we fail when domains are badly formatted.""" - test_prefixes = ["TEST_PREFIX_", "TEST_PREFIX2_"] - config_mock.return_value = dict( - domains=[f"cant_split_domain"], domain_prefixes=test_prefixes - ) - config_fetch_mock.side_effect = config_mock().get + config_mock.return_value = _get_config_mock(domains=["cant_split_domain"]) connect_mock.return_value = None - with mock.patch("app.flush_workers", return_value=None): + with self.assertRaises(app.InvalidDomain): app.main() - connect_mock.assert_called_with() + connect_mock.assert_not_called() @mock.patch("time.sleep", side_effect=InterruptedError) @mock.patch.object(app, "wg_flush_stale_peers") diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index 2216d69..c58be0d 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -1,46 +1,25 @@ #!/usr/bin/env python3 """Process messages from MQTT.""" -import paho.mqtt.client as mqtt - # TODO(ruairi): Deprecate __init__.py from config, as it masks namespace. -from wgkex.config.config import load_config import socket import re -from typing import Optional, Dict, Any, Union -from wgkex.common import logger -from wgkex.worker.msg_queue import q - - -def fetch_from_config(var: str) -> Optional[Union[Dict[str, str], str]]: - """Fetches values from configuration file. - - Arguments: - var: The variable to fetch from config. +from typing import Any - Raises: - ValueError: If given key cannot be found in configuration. +import paho.mqtt.client as mqtt - Returns: - The given variable from configuration. - """ - config = load_config() - ret = config.get(var) - if not ret: - raise ValueError("Failed to get %s from configuration, failing", var) - return config.get(var) +from wgkex.common import logger +from wgkex.config.config import get_config +from wgkex.worker.msg_queue import q +from wgkex.worker.netlink import link_handler, WireGuardClient def connect() -> None: - """Connect to MQTT for the given domains. - - Argument: - domains: The domains to connect to. - """ - base_config = fetch_from_config("mqtt") - broker_address = base_config.get("broker_url") - broker_port = base_config.get("broker_port") - broker_keepalive = base_config.get("keepalive") + """Connect to MQTT.""" + base_config = get_config().mqtt + broker_address = base_config.broker_url + broker_port = base_config.broker_port + broker_keepalive = base_config.keepalive # TODO(ruairi): Move the hostname to a global variable. client = mqtt.Client(socket.gethostname()) @@ -64,7 +43,7 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None: rc: The MQTT rc. """ logger.debug("Connected with result code " + str(rc)) - domains = load_config().get("domains") + domains = get_config().domains # Subscribing in on_connect() means that if we lose the connection and # reconnect then subscriptions will be renewed. @@ -84,7 +63,7 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> """ # TODO(ruairi): Check bounds and raise exception here. logger.debug("Got message %s from MTQQ", message) - domain_prefixes = load_config().get("domain_prefixes") + domain_prefixes = get_config().domain_prefixes domain = None for domain_prefix in domain_prefixes: domain = re.search(r"/.*" + domain_prefix + "(\w+)/", message.topic) diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index 8aece41..cd19aa0 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -5,47 +5,53 @@ from wgkex.worker import mqtt -class MQTTTest(unittest.TestCase): - @mock.patch.object(mqtt, "load_config") - def test_fetch_from_config_success(self, config_mock): - """Ensure we can fetch a value from config.""" - config_mock.return_value = dict(key="value") - self.assertEqual("value", mqtt.fetch_from_config("key")) +def _get_config_mock(domains=None, mqtt=None): + test_prefix = "_ffmuc_" + config_mock = mock.MagicMock() + config_mock.domains = ( + domains if domains is not None else [f"{test_prefix}domain.one"] + ) + config_mock.domain_prefix = test_prefix + if mqtt: + config_mock.mqtt = mqtt + return config_mock - @mock.patch.object(mqtt, "load_config") - def test_fetch_from_config_fails_no_key(self, config_mock): - """Tests we fail with ValueError for missing key in config.""" - config_mock.return_value = dict(key="value") - with self.assertRaises(ValueError): - mqtt.fetch_from_config("does_not_exist") +class MQTTTest(unittest.TestCase): @mock.patch.object(mqtt.mqtt, "Client") @mock.patch.object(mqtt.socket, "gethostname") - @mock.patch.object(mqtt, "load_config") + @mock.patch.object(mqtt, "get_config") def test_connect_success(self, config_mock, hostname_mock, mqtt_mock): """Tests successful connection to MQTT server.""" hostname_mock.return_value = "hostname" - config_mock.return_value = dict(mqtt={"broker_url": "some_url"}) + config_mqtt_mock = mock.MagicMock() + config_mqtt_mock.broker_url = "some_url" + config_mqtt_mock.broker_port = 1833 + config_mqtt_mock.keepalive = False + config_mock.return_value = _get_config_mock(mqtt=config_mqtt_mock) mqtt.connect() mqtt_mock.assert_has_calls( - [mock.call().connect("some_url", port=None, keepalive=None)], + [mock.call().connect("some_url", port=1833, keepalive=False)], any_order=True, ) @mock.patch.object(mqtt.mqtt, "Client") - @mock.patch.object(mqtt, "load_config") + @mock.patch.object(mqtt, "get_config") def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock): """Tests failure for connect - ValueError.""" mqtt_mock.side_effect = ValueError("barf") - config_mock.return_value = dict(mqtt={"broker_url": "some_url"}) + config_mqtt_mock = mock.MagicMock() + config_mqtt_mock.broker_url = "some_url" + config_mock.return_value = _get_config_mock(mqtt=config_mqtt_mock) with self.assertRaises(ValueError): mqtt.connect() -""" @mock.patch.object(msg_queue, "link_handler") - @mock.patch.object(mqtt, "load_config") +""" @mock.patch.object(msg_queue, "link_handler") + @mock.patch.object(mqtt, "get_config") def test_on_message_success(self, config_mock, link_mock): - config_mock.return_value = {"domain_prefix": "_ffmuc_"} + # Tests on_message for success. + config_mock.return_value = _get_config_mock() link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") mqtt_msg.topic = "wireguard/_ffmuc_domain1/gateway" @@ -63,27 +69,25 @@ def test_on_message_success(self, config_mock, link_mock): ) @mock.patch.object(msg_queue, "link_handler") - @mock.patch.object(mqtt, "load_config") + @mock.patch.object(mqtt, "get_config") def test_on_message_fails_no_domain(self, config_mock, link_mock): - config_mock.return_value = { - "domain_prefix": "ffmuc_", - "log_level": "DEBUG", - "domains": ["a", "b"], - "mqtt": { - "broker_port": 1883, - "broker_url": "mqtt://broker", - "keepalive": 5, - "password": "pass", - "tls": True, - "username": "user", - }, - } + # Tests on_message for failure to parse domain. + config_mqtt_mock = mock.MagicMock() + config_mqtt_mock.broker_url = "mqtt://broker" + config_mqtt_mock.broker_port = 1883 + config_mqtt_mock.keepalive = 5 + config_mqtt_mock.password = "pass" + config_mqtt_mock.tls = True + config_mqtt_mock.username = "user" + config_mock.return_value = _get_config_mock( + domains=["a", "b"], mqtt=config_mqtt_mock + ) link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") mqtt_msg.topic = "wireguard/bad_domain_match" with self.assertRaises(ValueError): mqtt.on_message(None, None, mqtt_msg) - """ +""" if __name__ == "__main__": unittest.main() From 216008e70f7fff2c89ecdcbdd58d4a49f28b0579 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sun, 17 Dec 2023 21:27:15 +0000 Subject: [PATCH 3/8] Publish worker metrics and data, assign gateways to clients * Workers publish their number of connected peers per domain * Workers publish their status, i.e. up or down * The new /api/v2/exchange endpoint returns a predetermined gateway endpoint for clients * This gateway is chosen based on weighted loadbalancing between online workers/gateways * Fetch worker data through netlink and publish with MQTT: * Read worker pubkey, port and link address from interface. * Publish it together with the external domain / address (read from the config file) via MQTT to the broker. --- README.md | 47 ++++++++- wgkex.yaml.example | 26 ++++- wgkex/broker/BUILD | 13 +++ wgkex/broker/app.py | 163 ++++++++++++++++++++++++++---- wgkex/broker/metrics.py | 122 ++++++++++++++++++++++ wgkex/common/BUILD | 6 ++ wgkex/common/mqtt.py | 6 ++ wgkex/config/config.py | 56 +++++++++++ wgkex/worker/BUILD | 3 +- wgkex/worker/app.py | 18 +++- wgkex/worker/app_test.py | 5 +- wgkex/worker/mqtt.py | 190 +++++++++++++++++++++++++++++++++-- wgkex/worker/mqtt_test.py | 15 +-- wgkex/worker/netlink.py | 91 ++++++++++++++--- wgkex/worker/netlink_test.py | 2 + 15 files changed, 705 insertions(+), 58 deletions(-) create mode 100644 wgkex/broker/metrics.py create mode 100644 wgkex/common/mqtt.py diff --git a/README.md b/README.md index 2c32b25..9f673f8 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ - [Overview](#overview) - [Frontend broker](#frontend-broker) - [POST /api/v1/wg/key/exchange](#post-apiv1wgkeyexchange) + - [POST /api/v2/wg/key/exchange](#post-apiv2wgkeyexchange) - [Backend worker](#backend-worker) - [Installation](#installation) - [Configuration](#configuration) @@ -41,6 +42,7 @@ The frontend broker exposes the following API endpoints for use: ``` /api/v1/wg/key/exchange +/api/v2/wg/key/exchange ``` The listen address and port for the Flask server can be configured in `wgkex.yaml` under the `broker_listen` key: @@ -66,6 +68,35 @@ JSON POST'd to this endpoint should be in this format: The broker will validate the domain and public key, and if valid, will push the key onto the MQTT bus. + +#### POST /api/v2/wg/key/exchange + +JSON POST'd to this endpoint should be in this format: + +```json +{ + "domain": "CONFIGURED_DOMAIN", + "public_key": "PUBLIC_KEY" +} +``` + +The broker will validate the domain and public key, and if valid, will push the key onto the MQTT bus. +Additionally it chooses a worker (aka gateway, endpoint) that the client should connect to. +The response is JSON data containing the connection details for the chosen gateway: + +```json +{ + "Endpoint": { + "Address": "GATEWAY_ADDRESS", + "Port": "GATEWAY_WIREGUARD_PORT", + "AllowedIPs": [ + "GATEWAY_WIREGUARD_INTERFACE_ADDRESS" + ], + "PublicKey": "GATEWAY_PUBLIC_KEY" + } +} +``` + ### Backend worker The backend (worker) waits for new keys to appear on the MQTT message bus. Once a new key appears, the worker performs @@ -141,8 +172,13 @@ The test can be run using `bazel test ... --test_output=all` or `python3 -m unit The client can be used via CLI: ``` -$ wget -q -O- --post-data='{"domain": "ffmuc_welt","public_key": "o52Ge+Rpj4CUSitVag9mS7pSXUesNM0ESnvj/wwehkg="}' --header='Content-Type:application/json' 'http://127.0.0.1:5000/api/v1/wg/key/exchange' +$ wget -q -O- --post-data='{"domain": "ffmuc_welt","public_key": "o52Ge+Rpj4CUSitVag9mS7pSXUesNM0ESnvj/wwehkg="}' --header='Content-Type:application/json' 'http://127.0.0.1:5000/api/v2/wg/key/exchange' { + "Endpoint": { + "Address": "gw04.ext.ffmuc.net:40011", + "LinkAddress": "fe80::27c:16ff:fec0:6c74", + "PublicKey": "TszFS3oFRdhsJP3K0VOlklGMGYZy+oFCtlaghXJqW2g=" + }, "Message": "OK" } ``` @@ -153,7 +189,7 @@ Or via python: import requests key_data = {"domain": "ffmuc_welt","public_key": "o52Ge+Rpj4CUSitVag9mS7pSXUesNM0ESnvj/wwehkg="} broker_url = "http://127.0.0.1:5000" -push_key = requests.get(f'{broker_url}/api/v1/wg/key/exchange', json=key_data) +push_key = requests.get(f'{broker_url}/api/v2/wg/key/exchange', json=key_data) print(f'Key push was: {push_key.json().get("Message")}') ``` @@ -180,6 +216,13 @@ sudo ip link set wg-welt up sudo ip link set vx-welt up ``` +### MQTT topics + +Publishing keys broker->worker: `wireguard/{domain}/{worker}` +Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers` +Publishing worker status: `wireguard-worker/{worker}/status` +Publishing worker data: `wireguard-worker/{worker}/{domain}/data` + ## Contact [Freifunk Munich Mattermost](https://chat.ffmuc.net) diff --git a/wgkex.yaml.example b/wgkex.yaml.example index 7b82c71..30340fe 100644 --- a/wgkex.yaml.example +++ b/wgkex.yaml.example @@ -1,3 +1,4 @@ +# [broker] The domains that should be accepted by clients and for which matching WireGuard interfaces exist domains: - ffmuc_muc_cty - ffmuc_muc_nord @@ -6,6 +7,25 @@ domains: - ffmuc_muc_west - ffmuc_welt - ffwert_city +# [broker, worker] The prefix is trimmed from the domain name and replaced with 'wg-' and 'vx-' +# to calculate the WireGuard and VXLAN interface names +domain_prefixes: + - ffmuc_ + - ffdon_ + - ffwert_ +# [broker] The dict of workers mapping their hostname to their respective weight for weighted peer distribution +workers: + gw04.in.ffmuc.net: + weight: 30 + gw05.in.ffmuc.net: + weight: 30 + gw06.in.ffmuc.net: + weight: 20 + gw07.in.ffmuc.net: + weight: 20 +# [worker] The external hostname of this worker +externalName: gw04.ext.ffmuc.net +# [broker, worker] MQTT connection informations mqtt: broker_url: broker.hivemq.com broker_port: 1883 @@ -13,13 +33,11 @@ mqtt: password: SECRET keepalive: 5 tls: False +# [broker] broker_listen: host: 0.0.0.0 port: 5000 -domain_prefixes: - - ffmuc_ - - ffdon_ - - ffwert_ +# [broker, worker] logging_config: formatters: standard: diff --git a/wgkex/broker/BUILD b/wgkex/broker/BUILD index 260fe45..414da32 100644 --- a/wgkex/broker/BUILD +++ b/wgkex/broker/BUILD @@ -1,6 +1,17 @@ load("@rules_python//python:defs.bzl", "py_binary", "py_test") load("@pip//:requirements.bzl", "requirement") +py_library( + name = "metrics", + srcs = ["metrics.py"], + visibility = ["//visibility:public"], + deps = [ + "//wgkex/common:mqtt", + "//wgkex/common:logger", + "//wgkex/config:config", + ], +) + py_binary( name="app", srcs=["app.py"], @@ -11,5 +22,7 @@ py_binary( requirement("flask-mqtt"), requirement("waitress"), "//wgkex/config:config", + "//wgkex/common:mqtt", + ":metrics" ], ) diff --git a/wgkex/broker/app.py b/wgkex/broker/app.py index f5d23d7..1d753ff 100644 --- a/wgkex/broker/app.py +++ b/wgkex/broker/app.py @@ -1,22 +1,25 @@ #!/usr/bin/env python3 """wgkex broker""" -import re import dataclasses -from typing import Tuple, Any +import json +import re +from typing import Dict, Tuple, Any -from flask import Flask -from flask import abort -from flask import jsonify -from flask import render_template -from flask import request +import paho.mqtt.client as mqtt_client +from flask import Flask, render_template, request, Response from flask.app import Flask as Flask_app from flask_mqtt import Mqtt -import paho.mqtt.client as mqtt_client from waitress import serve from wgkex.config import config from wgkex.common import logger from wgkex.common.utils import is_valid_domain +from wgkex.broker.metrics import WorkerMetricsCollection +from wgkex.common.mqtt import ( + CONNECTED_PEERS_METRIC, + TOPIC_WORKER_STATUS, + TOPIC_WORKER_WG_DATA, +) WG_PUBKEY_PATTERN = re.compile(r"^[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=$") @@ -68,34 +71,86 @@ def _fetch_app_config() -> Flask_app: app = _fetch_app_config() mqtt = Mqtt(app) +worker_metrics = WorkerMetricsCollection() +worker_data: Dict[Tuple[str, str], Dict] = {} @app.route("/", methods=["GET"]) -def index() -> None: +def index() -> str: """Returns main page""" return render_template("index.html") @app.route("/api/v1/wg/key/exchange", methods=["POST"]) -def wg_key_exchange() -> Tuple[str, int]: +def wg_api_v1_key_exchange() -> Tuple[Response | Dict, int]: """Retrieves a new key and validates. - Returns: Status message. """ try: data = KeyExchange.from_dict(request.get_json(force=True)) - except TypeError as ex: - return abort(400, jsonify({"error": {"message": str(ex)}})) + except Exception as ex: + return {"error": {"message": str(ex)}}, 400 key = data.public_key domain = data.domain # in case we want to decide here later we want to publish it only to dedicated gateways gateway = "all" - logger.info(f"wg_key_exchange: Domain: {domain}, Key:{key}") + logger.info(f"wg_api_v1_key_exchange: Domain: {domain}, Key:{key}") mqtt.publish(f"wireguard/{domain}/{gateway}", key) - return jsonify({"Message": "OK"}), 200 + return {"Message": "OK"}, 200 + + +@app.route("/api/v2/wg/key/exchange", methods=["POST"]) +def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]: + """Retrieves a new key, validates it and responds with a worker/gateway the client should connect to. + + Returns: + Status message, Endpoint with address/domain, port pubic key and link address. + """ + try: + data = KeyExchange.from_dict(request.get_json(force=True)) + except Exception as ex: + return {"error": {"message": str(ex)}}, 400 + + key = data.public_key + domain = data.domain + # in case we want to decide here later we want to publish it only to dedicated gateways + gateway = "all" + logger.info(f"wg_api_v2_key_exchange: Domain: {domain}, Key:{key}") + + mqtt.publish(f"wireguard/{domain}/{gateway}", key) + + best_worker, diff, current_peers = worker_metrics.get_best_worker(domain) + if best_worker is None: + logger.warning(f"No worker online for domain {domain}") + return { + "error": { + "message": "no gateway online for this domain, please check the domain value and try again later" + } + }, 400 + + worker_metrics.update( + best_worker, domain, CONNECTED_PEERS_METRIC, current_peers + 1 + ) + logger.debug( + f"Chose worker {best_worker} with {current_peers} connected clients ({diff})" + ) + + w_data = worker_data.get((best_worker, domain), None) + if w_data is None: + logger.error(f"Couldn't get worker endpoint data for {best_worker}/{domain}") + return {"error": {"message": "could not get gateway data"}}, 500 + + endpoint = { + "Address": w_data.get("ExternalAddress"), + "Port": str(w_data.get("Port")), + "AllowedIPs": [w_data.get("LinkAddress")], + "PublicKey": w_data.get("PublicKey"), + } + + return {"Endpoint": endpoint}, 200 @mqtt.on_connect() @@ -109,7 +164,69 @@ def handle_mqtt_connect( app.config["MQTT_BROKER_URL"], app.config["MQTT_BROKER_PORT"] ) ) - # mqtt.subscribe("wireguard/#") + mqtt.subscribe("wireguard-metrics/#") + mqtt.subscribe(TOPIC_WORKER_STATUS.format(worker="+")) + mqtt.subscribe(TOPIC_WORKER_WG_DATA.format(worker="+", domain="+")) + + +@mqtt.on_topic("wireguard-metrics/#") +def handle_mqtt_message_metrics( + client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage +) -> None: + """Processes published metrics from workers.""" + logger.debug( + f"MQTT message received on {message.topic}: {message.payload.decode()}" + ) + _, domain, worker, metric = message.topic.split("/", 3) + if not is_valid_domain(domain): + logger.error(f"Domain {domain} not in configured domains") + return + + if not worker or not metric: + logger.error("Ignored MQTT message with empty worker or metrics label") + return + + data = int(message.payload) + + logger.info(f"Update worker metrics: {metric} on {worker}/{domain} = {data}") + worker_metrics.update(worker, domain, metric, data) + + +@mqtt.on_topic(TOPIC_WORKER_STATUS.format(worker="+")) +def handle_mqtt_message_status( + client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage +) -> None: + """Processes status messages from workers.""" + _, worker, _ = message.topic.split("/", 2) + + status = int(message.payload) + if status < 1: + logger.warning(f"Marking worker as offline: {worker}") + worker_metrics.set_offline(worker) + else: + logger.warning(f"Marking worker as online: {worker}") + worker_metrics.set_online(worker) + + +@mqtt.on_topic(TOPIC_WORKER_WG_DATA.format(worker="+", domain="+")) +def handle_mqtt_message_data( + client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage +) -> None: + """Processes data messages from workers. + + Stores them in a local dict""" + _, worker, domain, _ = message.topic.split("/", 3) + if not is_valid_domain(domain): + logger.error(f"Domain {domain} not in configured domains.") + return + + data = json.loads(message.payload) + if not isinstance(data, dict): + logger.error("Invalid worker data received for %s/%s: %s", worker, domain, data) + return + + logger.info("Worker data received for %s/%s: %s", worker, domain, data) + worker_data[(worker, domain)] = data @mqtt.on_message() @@ -117,7 +234,6 @@ def handle_mqtt_message( client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage ) -> None: """Prints message contents.""" - # TODO(ruairi): Clarify current usage of this function. logger.debug( f"MQTT message received on {message.topic}: {message.payload.decode()}" ) @@ -141,6 +257,19 @@ def is_valid_wg_pubkey(pubkey: str) -> str: return pubkey +def join_host_port(host: str, port: str) -> str: + """Concatenate a port string with a host string using a colon. + The host may be either a hostname, IPv4 or IPv6 address. + An IPv6 address as host will be automatically encapsulated in square brackets. + + Returns: + The joined host:port string + """ + if host.find(":") >= 0: + return "[" + host + "]:" + port + return host + ":" + port + + if __name__ == "__main__": listen_host = None listen_port = None diff --git a/wgkex/broker/metrics.py b/wgkex/broker/metrics.py new file mode 100644 index 0000000..9e5fdc4 --- /dev/null +++ b/wgkex/broker/metrics.py @@ -0,0 +1,122 @@ +import dataclasses +from operator import itemgetter +from typing import Any, Dict, Optional, Tuple + +from wgkex.config import config +from wgkex.common import logger +from wgkex.common.mqtt import CONNECTED_PEERS_METRIC + + +@dataclasses.dataclass +class WorkerMetrics: + """Metrics of a single worker""" + + worker: str + # domain -> [metric name -> metric data] + domain_data: Dict[str, Dict[str, Any]] = dataclasses.field(default_factory=dict) + online: bool = False + + def is_online(self, domain: str = "") -> bool: + if domain: + return ( + self.online + and self.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC, -1) >= 0 + ) + else: + return self.online + + def get_domain_metrics(self, domain: str) -> Dict[str, Any]: + return self.domain_data.get(domain, {}) + + def set_metric(self, domain: str, metric: str, value: Any) -> None: + if domain in self.domain_data: + self.domain_data[domain][metric] = value + else: + self.domain_data[domain] = {metric: value} + + +@dataclasses.dataclass +class WorkerMetricsCollection: + """A container for all worker metrics""" + + # worker -> WorkerMetrics + data: Dict[str, WorkerMetrics] = dataclasses.field(default_factory=dict) + + def get(self, worker: str) -> WorkerMetrics: + return self.data.get(worker, WorkerMetrics(worker=worker)) + + def set(self, worker: str, metrics: WorkerMetrics) -> None: + self.data[worker] = metrics + + def update(self, worker: str, domain: str, metric: str, value: Any) -> None: + if worker in self.data: + self.data[worker].set_metric(domain, metric, value) + else: + metrics = WorkerMetrics(worker) + metrics.set_metric(domain, metric, value) + self.data[worker] = metrics + + def set_online(self, worker: str) -> None: + if worker in self.data: + self.data[worker].online = True + else: + metrics = WorkerMetrics(worker) + metrics.online = True + self.data[worker] = metrics + + def set_offline(self, worker: str) -> None: + if worker in self.data: + self.data[worker].online = False + + def get_total_peers(self) -> int: + total = 0 + for worker in self.data: + worker_data = self.data.get(worker) + if not worker_data: + continue + for domain in worker_data.domain_data: + total += max( + worker_data.get_domain_metrics(domain).get( + CONNECTED_PEERS_METRIC, 0 + ), + 0, + ) + + return total + + def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]: + """Analyzes the metrics and determines the best worker that a new client should connect to. + The best worker is defined as the one with the most number of clients missing + to its should-be target value according to its weight. + + Returns: + A 3-tuple containing the worker name, difference to target peers, number of connected peers. + The worker name can be None if none is online. + """ + # Map metrics to a list of (target diff, peer count, worker) tuples for online workers + + peers_worker_tuples = [] + total_peers = self.get_total_peers() + workerCfg = config.get_config().workers + + for wm in self.data.values(): + if not wm.online: + continue + peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC, -1) + if peers < 0: + continue + + rel_weight = workerCfg.relative_worker_weight(wm.worker) + target = rel_weight * total_peers + diff = peers - target + logger.debug( + f"Worker {wm.worker}: rel weight {rel_weight}, target {target} (total {total_peers}), diff {diff}" + ) + peers_worker_tuples.append((diff, peers, wm.worker)) + + peers_worker_tuples = sorted(peers_worker_tuples, key=itemgetter(0)) + + if len(peers_worker_tuples) > 0: + best = peers_worker_tuples[0] + return best[2], best[0], best[1] + return None, 0, 0 diff --git a/wgkex/common/BUILD b/wgkex/common/BUILD index 7a79f93..b203348 100644 --- a/wgkex/common/BUILD +++ b/wgkex/common/BUILD @@ -26,3 +26,9 @@ py_library( srcs = ["logger.py"], visibility = ["//visibility:public"] ) + +py_library( + name = "mqtt", + srcs = ["mqtt.py"], + visibility = ["//visibility:public"] +) diff --git a/wgkex/common/mqtt.py b/wgkex/common/mqtt.py new file mode 100644 index 0000000..69bf15b --- /dev/null +++ b/wgkex/common/mqtt.py @@ -0,0 +1,6 @@ +"""Common MQTT constants like topic string templates""" + +TOPIC_WORKER_WG_DATA = "wireguard-worker/{worker}/{domain}/data" +TOPIC_WORKER_STATUS = "wireguard-worker/{worker}/status" +CONNECTED_PEERS_METRIC = "connected_peers" +TOPIC_CONNECTED_PEERS = "wireguard-metrics/{domain}/{worker}/" + CONNECTED_PEERS_METRIC diff --git a/wgkex/config/config.py b/wgkex/config/config.py index 29ba0a5..b0239e2 100644 --- a/wgkex/config/config.py +++ b/wgkex/config/config.py @@ -20,6 +20,55 @@ class ConfigFileNotFoundError(Error): WG_CONFIG_DEFAULT_LOCATION = "/etc/wgkex.yaml" +@dataclasses.dataclass +class Worker: + """A representation of the values of the 'workers' dict in the configuration file. + + Attributes: + weight: The relative weight of a worker, defaults to 1. + """ + + weight: int + + @classmethod + def from_dict(cls, worker_cfg: Dict[str, Any]) -> "Worker": + return cls( + weight=int(worker_cfg["weight"]) if worker_cfg["weight"] else 1, + ) + + +@dataclasses.dataclass +class Workers: + """A representation of the 'workers' key in the configuration file. + + Attributes: + total_weight: Calculated on init, the total weight of all configured workers. + """ + + total_weight: int + _workers: Dict[str, Worker] + + @classmethod + def from_dict(cls, workers_cfg: Dict[str, Dict[str, Any]]) -> "Workers": + d = {key: Worker.from_dict(value) for (key, value) in workers_cfg.items()} + + total = 0 + for worker in d.values(): + total += worker.weight + total = max(total, 1) + + return cls(total_weight=total, _workers=d) + + def get(self, worker: str) -> Optional[Worker]: + return self._workers.get(worker) + + def relative_worker_weight(self, worker_name: str) -> float: + worker = self.get(worker_name) + if worker is None: + return 1 / self.total_weight + return worker.weight / self.total_weight + + @dataclasses.dataclass class BrokerListen: """A representation of the 'broker_listen' key in Configuration file. @@ -88,6 +137,8 @@ class Config: domains: The list of domains to listen for. domain_prefixes: The list of prefixes to pre-pend to a given domain. mqtt: The MQTT configuration. + workers: The worker weights configuration (broker-only). + externalName: The publicly resolvable domain name or public IP address of this worker (worker-only). """ raw: Dict[str, Any] @@ -95,6 +146,8 @@ class Config: domain_prefixes: List[str] broker_listen: BrokerListen mqtt: MQTT + workers: Workers + external_name: Optional[str] @classmethod def from_dict(cls, cfg: Dict[str, Any]) -> "Config": @@ -106,12 +159,15 @@ def from_dict(cls, cfg: Dict[str, Any]) -> "Config": """ broker_listen = BrokerListen.from_dict(cfg.get("broker_listen", {})) mqtt_cfg = MQTT.from_dict(cfg["mqtt"]) + workers_cfg = Workers.from_dict(cfg.get("workers", {})) return cls( raw=cfg, domains=cfg["domains"], domain_prefixes=cfg["domain_prefixes"], broker_listen=broker_listen, mqtt=mqtt_cfg, + workers=workers_cfg, + external_name=cfg.get("externalName"), ) def get(self, key: str) -> Any: diff --git a/wgkex/worker/BUILD b/wgkex/worker/BUILD index 7f1c2c3..b1d9b6d 100644 --- a/wgkex/worker/BUILD +++ b/wgkex/worker/BUILD @@ -35,8 +35,9 @@ py_library( requirement("NetLink"), requirement("paho-mqtt"), requirement("pyroute2"), - "//wgkex/common:utils", "//wgkex/common:logger", + "//wgkex/common:mqtt", + "//wgkex/common:utils", "//wgkex/config:config", ":msg_queue", ":netlink", diff --git a/wgkex/worker/app.py b/wgkex/worker/app.py index 7fd8b7f..9a07d97 100644 --- a/wgkex/worker/app.py +++ b/wgkex/worker/app.py @@ -1,5 +1,7 @@ """Initialises the MQTT worker.""" +import signal +import sys import threading import time from typing import Text @@ -67,7 +69,9 @@ def clean_up_worker() -> None: domain, ) continue - thread = threading.Thread(target=flush_workers, args=(cleaned_domain,)) + thread = threading.Thread( + target=flush_workers, args=(cleaned_domain,), daemon=True + ) thread.start() if cleanup_counter < len(domains): logger.error( @@ -111,6 +115,16 @@ def main(): DomainsNotInConfig: If no domains were found in configuration file. DomainsAreNotUnique: If there were non-unique domains after stripping prefix """ + exit_event = threading.Event() + + def on_exit(sig_number, stack_frame) -> None: + logger.info("Shutting down...") + exit_event.set() + time.sleep(2) + sys.exit() + + signal.signal(signal.SIGINT, on_exit) + domains = config.get_config().domains prefixes = config.get_config().domain_prefixes if not domains: @@ -122,7 +136,7 @@ def main(): raise InvalidDomain(f"Domain {domain} has invalid prefix.") clean_up_worker() watch_queue() - mqtt.connect() + mqtt.connect(exit_event) if __name__ == "__main__": diff --git a/wgkex/worker/app_test.py b/wgkex/worker/app_test.py index cfb2f37..0cf525f 100644 --- a/wgkex/worker/app_test.py +++ b/wgkex/worker/app_test.py @@ -2,7 +2,6 @@ import unittest import mock -import wgkex.config.config from wgkex.worker import app @@ -72,7 +71,7 @@ def test_main_success(self, connect_mock, config_mock): @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) - def test_main_fails_no_domain(self, connect_mock, config_mock, config_fetch_mock): + def test_main_fails_no_domain(self, connect_mock, config_mock): """Ensure we fail when domains are not configured.""" config_mock.return_value = _get_config_mock(domains=[]) connect_mock.return_value = None @@ -81,7 +80,7 @@ def test_main_fails_no_domain(self, connect_mock, config_mock, config_fetch_mock @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) - def test_main_fails_bad_domain(self, connect_mock, config_mock, config_fetch_mock): + def test_main_fails_bad_domain(self, connect_mock, config_mock): """Ensure we fail when domains are badly formatted.""" config_mock.return_value = _get_config_mock(domains=["cant_split_domain"]) connect_mock.return_value = None diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index c58be0d..ef45345 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -2,36 +2,105 @@ """Process messages from MQTT.""" # TODO(ruairi): Deprecate __init__.py from config, as it masks namespace. -import socket +import json import re -from typing import Any +import socket +import threading +from typing import Any, Optional import paho.mqtt.client as mqtt from wgkex.common import logger +from wgkex.common.mqtt import ( + TOPIC_CONNECTED_PEERS, + TOPIC_WORKER_STATUS, + TOPIC_WORKER_WG_DATA, +) from wgkex.config.config import get_config from wgkex.worker.msg_queue import q -from wgkex.worker.netlink import link_handler, WireGuardClient +from wgkex.worker.netlink import ( + get_device_data, + link_handler, + get_connected_peers_count, + WireGuardClient, +) + +_HOSTNAME = socket.gethostname() +_METRICS_SEND_INTERVAL = 60 + +def connect(exit_event: threading.Event) -> None: + """Connect to MQTT. -def connect() -> None: - """Connect to MQTT.""" + Argument: + exit_event: A threading.Event that signals application shutdown. + """ base_config = get_config().mqtt broker_address = base_config.broker_url broker_port = base_config.broker_port broker_keepalive = base_config.keepalive - # TODO(ruairi): Move the hostname to a global variable. - client = mqtt.Client(socket.gethostname()) + client = mqtt.Client(_HOSTNAME) + domains = get_config().domains + + # Register LWT to set worker status down when lossing connection + client.will_set(TOPIC_WORKER_STATUS.format(worker=_HOSTNAME), 0, qos=1, retain=True) # Register handlers client.on_connect = on_connect + client.on_disconnect = on_disconnect client.on_message = on_message + client.message_callback_add("wireguard/#", on_message_wireguard) logger.info("connecting to broker %s", broker_address) client.connect(broker_address, port=broker_port, keepalive=broker_keepalive) + + # Start background threads that should not be restarted on reconnect + + # Mark worker as offline on graceful shutdown, after exit_event is set + def mark_offline_on_exit(exit_event: threading.Event): + exit_event.wait() + if client.is_connected(): + logger.info("Marking worker as down") + client.publish( + TOPIC_WORKER_STATUS.format(worker=_HOSTNAME), 0, qos=1, retain=True + ) + + mark_offline_on_exit_thread = threading.Thread( + target=mark_offline_on_exit, args=(exit_event,) + ) + mark_offline_on_exit_thread.start() + + for domain in domains: + # Schedule repeated metrics publishing + metrics_thread = threading.Thread( + target=publish_metrics_loop, args=(exit_event, client, domain) + ) + metrics_thread.start() + client.loop_forever() +def on_disconnect(client: mqtt.Client, userdata: Any, rc): + """Handles MQTT disconnect and logs the event + + Expected signature for MQTT v3.1.1 and v3.1 is: + disconnect_callback(client, userdata, rc) + + and for MQTT v5.0: + disconnect_callback(client, userdata, reasonCode, properties) + + Arguments: + client: the client instance for this callback + userdata: the private user data as set in Client() or userdata_set() + rc: the disconnection result + The rc parameter indicates the disconnection state. If + MQTT_ERR_SUCCESS (0), the callback was called in response to + a disconnect() call. If any other value the disconnection + was unexpected, such as might be caused by a network error. + """ + logger.debug("Disconnected with result code " + str(rc)) + + # The callback for when the client receives a CONNACK response from the server. def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None: """Handles MQTT connect and subscribes to topics on connect @@ -45,15 +114,59 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None: logger.debug("Connected with result code " + str(rc)) domains = get_config().domains - # Subscribing in on_connect() means that if we lose the connection and - # reconnect then subscriptions will be renewed. + own_external_name = ( + get_config().external_name + if get_config().external_name is not None + else _HOSTNAME + ) + for domain in domains: + # Subscribing in on_connect() means that if we lose the connection and + # reconnect then subscriptions will be renewed. topic = f"wireguard/{domain}/+" logger.info(f"Subscribing to topic {topic}") client.subscribe(topic) + # Publish worker data (WG pubkeys, ports, local addresses) + iface = wg_interface_name(domain) + if iface: + (port, public_key, link_address) = get_device_data(iface) + data = { + "ExternalAddress": own_external_name, + "Port": port, + "PublicKey": public_key, + "LinkAddress": link_address, + } + client.publish( + TOPIC_WORKER_WG_DATA.format(worker=_HOSTNAME, domain=domain), + json.dumps(data), + qos=1, + retain=True, + ) + else: + logger.error( + f"Could not get interface name for domain {domain}. Skipping worker data publication" + ) + + # Mark worker as online + client.publish(TOPIC_WORKER_STATUS.format(worker=_HOSTNAME), 1, qos=1, retain=True) + def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> None: + """Fallback handler for MQTT messages that do not match any other registered handler. + + Arguments: + client: the client instance for this callback. + userdata: the private user data. + message: The MQTT message. + """ + logger.info("Got unhandled message on %s from MQTT", message.topic) + return + + +def on_message_wireguard( + client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage +) -> None: """Processes messages from MQTT and forwards them to netlink. Arguments: @@ -62,11 +175,12 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> message: The MQTT message. """ # TODO(ruairi): Check bounds and raise exception here. - logger.debug("Got message %s from MTQQ", message) + logger.debug("Got message on %s from MQTT", message.topic) + domain_prefixes = get_config().domain_prefixes domain = None for domain_prefix in domain_prefixes: - domain = re.search(r"/.*" + domain_prefix + "(\w+)/", message.topic) + domain = re.search(r".*/" + domain_prefix + r"(\w+)/", message.topic) if domain: break if not domain: @@ -80,3 +194,57 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> f"Received create message for key {str(message.payload.decode('utf-8'))} on domain {domain} adding to queue" ) q.put((domain, str(message.payload.decode("utf-8")))) + + +def publish_metrics_loop( + exit_event: threading.Event, client: mqtt.Client, domain: str +) -> None: + """Continuously send metrics every METRICS_SEND_INTERVAL seconds for this gateway and the given domain.""" + logger.info("Scheduling metrics task for %s, ", domain) + + topic = TOPIC_CONNECTED_PEERS.format(domain=domain, worker=_HOSTNAME) + + while not exit_event.is_set(): + publish_metrics(client, topic, domain) + # This drifts slightly over time, doesn't matter for us + exit_event.wait(_METRICS_SEND_INTERVAL) + + # Set peers metric to -1 to mark worker as offline + # Use QoS 1 (at least once) to make sure the broker notices + client.publish(topic, -1, qos=1, retain=True) + + +def publish_metrics(client: mqtt.Client, topic: str, domain: str) -> None: + """Publish metrics for this gateway and the given domain. + + The metrics currently only consist of the number of connected peers. + """ + logger.debug(f"Publishing metrics for domain {domain}") + iface = wg_interface_name(domain) + if not iface: + logger.error( + f"Could not get interface name for domain {domain}. Skipping metrics publication" + ) + return + peer_count = get_connected_peers_count(iface) + + # Publish metrics, retain it at MQTT broker so restarted wgkex broker has metrics right away + client.publish(topic, peer_count, retain=True) + + +def wg_interface_name(domain: str) -> Optional[str]: + """Calculates the WireGuard interface name for a domain""" + domain_prefixes = get_config().domain_prefixes + cleaned_domain = None + for prefix in domain_prefixes: + try: + cleaned_domain = domain.split(prefix[1]) + except IndexError: + continue + break + if not cleaned_domain: + raise ValueError( + f"Could not find a match for {domain_prefixes} on {domain}" + ) + # this will not work, if we have non-unique prefix stripped domains + return f"wg-{cleaned_domain}" diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index cd19aa0..803125a 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -1,4 +1,5 @@ """Unit tests for mqtt.py""" +import threading import unittest import mock @@ -29,7 +30,9 @@ def test_connect_success(self, config_mock, hostname_mock, mqtt_mock): config_mqtt_mock.broker_port = 1833 config_mqtt_mock.keepalive = False config_mock.return_value = _get_config_mock(mqtt=config_mqtt_mock) - mqtt.connect() + ee = threading.Event() + mqtt.connect(ee) + ee.set() mqtt_mock.assert_has_calls( [mock.call().connect("some_url", port=1833, keepalive=False)], any_order=True, @@ -44,19 +47,19 @@ def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock): config_mqtt_mock.broker_url = "some_url" config_mock.return_value = _get_config_mock(mqtt=config_mqtt_mock) with self.assertRaises(ValueError): - mqtt.connect() + mqtt.connect(threading.Event()) """ @mock.patch.object(msg_queue, "link_handler") @mock.patch.object(mqtt, "get_config") - def test_on_message_success(self, config_mock, link_mock): + def test_on_message_wireguard_success(self, config_mock, link_mock): # Tests on_message for success. config_mock.return_value = _get_config_mock() link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") mqtt_msg.topic = "wireguard/_ffmuc_domain1/gateway" mqtt_msg.payload = b"PUB_KEY" - mqtt.on_message(None, None, mqtt_msg) + mqtt.on_message_wireguard(None, None, mqtt_msg) link_mock.assert_has_calls( [ mock.call( @@ -70,7 +73,7 @@ def test_on_message_success(self, config_mock, link_mock): @mock.patch.object(msg_queue, "link_handler") @mock.patch.object(mqtt, "get_config") - def test_on_message_fails_no_domain(self, config_mock, link_mock): + def test_on_message_wireguard_fails_no_domain(self, config_mock, link_mock): # Tests on_message for failure to parse domain. config_mqtt_mock = mock.MagicMock() config_mqtt_mock.broker_url = "mqtt://broker" @@ -86,7 +89,7 @@ def test_on_message_fails_no_domain(self, config_mock, link_mock): mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") mqtt_msg.topic = "wireguard/bad_domain_match" with self.assertRaises(ValueError): - mqtt.on_message(None, None, mqtt_msg) + mqtt.on_message_wireguard(None, None, mqtt_msg) """ if __name__ == "__main__": diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index d4f0656..a880d64 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -1,12 +1,15 @@ """Functions related to netlink manipulation for Wireguard, IPRoute and FDB on Linux.""" +# See https://docs.pyroute2.org/iproute.html for a documentation of the layout of netlink responses import hashlib import re from dataclasses import dataclass from datetime import datetime from datetime import timedelta from textwrap import wrap -from typing import Dict, List +from typing import Any, Dict, List, Tuple + import pyroute2 +import pyroute2.netlink from wgkex.common.utils import mac2eui64 from wgkex.common import logger @@ -191,18 +194,82 @@ def find_stale_wireguard_clients(wg_interface: str) -> List: "Starting search for stale wireguard peers for interface %s.", wg_interface ) with pyroute2.WireGuard() as wg: - all_clients = [] - peers_on_interface = wg.info(wg_interface) - logger.info("Got infos: %s.", peers_on_interface) - for peer in peers_on_interface: - clients = peer.get_attr("WGDEVICE_A_PEERS") - logger.info("Got clients: %s.", clients) - if clients: - all_clients.extend(clients) + all_peers = [] + msgs = wg.info(wg_interface) + logger.debug("Got infos for stale peers: %s.", msgs) + for msg in msgs: + peers = msg.get_attr("WGDEVICE_A_PEERS") + logger.debug("Got clients: %s.", peers) + if peers: + all_peers.extend(peers) ret = [ - client.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8") - for client in all_clients - if client.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get("tv_sec", int()) + peer.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8") + for peer in all_peers + if peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get("tv_sec", int()) < three_hrs_in_secs ] return ret + + +def get_connected_peers_count(wg_interface: str) -> int: + """Fetches and returns the number of connected peers, i.e. which had recent handshakes. + + Arguments: + wg_interface: The WireGuard interface to query. + + Returns: + # The number of peers which have recently seen a handshake. + """ + three_mins_ago_in_secs = int((datetime.now() - timedelta(minutes=3)).timestamp()) + logger.info("Counting connected wireguard peers for interface %s.", wg_interface) + with pyroute2.WireGuard() as wg: + msgs = wg.info(wg_interface) + logger.debug("Got infos for connected peers: %s.", msgs) + count = 0 + for msg in msgs: + peers = msg.get_attr("WGDEVICE_A_PEERS") + logger.debug("Got clients: %s.", peers) + if peers: + for peer in peers: + if ( + peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get( + "tv_sec", int() + ) + > three_mins_ago_in_secs + ): + count += 1 + + return count + + +def get_device_data(wg_interface: str) -> Tuple[Any, Any, Any]: + """Returns the listening port, public key and local IP address. + + Arguments: + wg_interface: The WireGuard interface to query. + + Returns: + # The listening port, public key, and local IP address of the WireGuard interface. + """ + logger.info("Reading data from interface %s.", wg_interface) + with pyroute2.WireGuard() as wg, pyroute2.NDB() as ndb: + msgs = wg.info(wg_interface) + logger.debug("Got infos for interface data: %s.", msgs) + if len(msgs) > 1: + logger.warning( + "Got multiple messages from netlink, expected one. Using only first one." + ) + info: pyroute2.netlink.nla = msgs[0] + + port = int(info.get_attr("WGDEVICE_A_LISTEN_PORT")) + public_key = info.get_attr("WGDEVICE_A_PUBLIC_KEY").decode("ascii") + link_address = ndb.interfaces[wg_interface].ipaddr[0].get("address") + + logger.debug( + "Interface data: port '%s', public key '%s', link address '%s", + port, + public_key, + link_address, + ) + + return (port, public_key, link_address) diff --git a/wgkex/worker/netlink_test.py b/wgkex/worker/netlink_test.py index c209731..68cd573 100644 --- a/wgkex/worker/netlink_test.py +++ b/wgkex/worker/netlink_test.py @@ -11,6 +11,8 @@ sys.modules["pyroute2"] = mock.MagicMock() sys.modules["pyroute2.WireGuard"] = mock.MagicMock() sys.modules["pyroute2.IPRoute"] = mock.MagicMock() +sys.modules["pyroute2.NDB"] = mock.MagicMock() +sys.modules["pyroute2.netlink"] = mock.MagicMock() from pyroute2 import WireGuard from pyroute2 import IPRoute From 21a125bca15b7ab3133751fa3764877b4eeb6642 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sun, 17 Dec 2023 21:27:15 +0000 Subject: [PATCH 4/8] Add more tests for broker/metrics.py --- wgkex/broker/BUILD | 9 +++ wgkex/broker/metrics.py | 10 ++- wgkex/broker/metrics_test.py | 125 +++++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 wgkex/broker/metrics_test.py diff --git a/wgkex/broker/BUILD b/wgkex/broker/BUILD index 414da32..780f340 100644 --- a/wgkex/broker/BUILD +++ b/wgkex/broker/BUILD @@ -12,6 +12,15 @@ py_library( ], ) +py_test( + name="metrics_test", + srcs=["metrics_test.py"], + deps = [ + "//wgkex/broker:metrics", + requirement("mock"), + ], +) + py_binary( name="app", srcs=["app.py"], diff --git a/wgkex/broker/metrics.py b/wgkex/broker/metrics.py index 9e5fdc4..a2e2893 100644 --- a/wgkex/broker/metrics.py +++ b/wgkex/broker/metrics.py @@ -97,16 +97,14 @@ def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]: peers_worker_tuples = [] total_peers = self.get_total_peers() - workerCfg = config.get_config().workers + worker_cfg = config.get_config().workers for wm in self.data.values(): - if not wm.online: - continue - peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC, -1) - if peers < 0: + if not wm.is_online(domain): continue - rel_weight = workerCfg.relative_worker_weight(wm.worker) + peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC) + rel_weight = worker_cfg.relative_worker_weight(wm.worker) target = rel_weight * total_peers diff = peers - target logger.debug( diff --git a/wgkex/broker/metrics_test.py b/wgkex/broker/metrics_test.py new file mode 100644 index 0000000..63b25bf --- /dev/null +++ b/wgkex/broker/metrics_test.py @@ -0,0 +1,125 @@ +import unittest + +import mock +from wgkex.config import config +from wgkex.broker.metrics import WorkerMetricsCollection + + +class TestMetrics(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + # Give each test a placeholder config + test_config = config.Config.from_dict( + { + "domains": [], + "domain_prefix": "", + "workers": {}, + "mqtt": {"broker_url": "", "username": "", "password": ""}, + } + ) + mocked_config = mock.create_autospec(spec=test_config, spec_set=True) + config._parsed_config = mocked_config + + @classmethod + def tearDownClass(cls) -> None: + config._parsed_config = None + + def test_set_online_matches_is_online(self): + """Verify set_online sets worker online and matches result of is_online.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_online("worker1") + + ret = worker_metrics.get("worker1").is_online() + self.assertTrue(ret) + + def test_set_offline_matches_is_online(self): + """Verify set_offline sets worker offline and matches negated result of is_online.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_offline("worker1") + + ret = worker_metrics.get("worker1").is_online() + self.assertFalse(ret) + + def test_unkown_is_offline(self): + """Verify an unkown worker is considered offline.""" + worker_metrics = WorkerMetricsCollection() + + ret = worker_metrics.get("worker1").is_online() + self.assertFalse(ret) + + def test_set_online_matches_is_online_domain(self): + """Verify set_online sets worker online and matches result of is_online with domain.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_online("worker1") + worker_metrics.update("worker1", "d", "connected_peers", 5) + + ret = worker_metrics.get("worker1").is_online("d") + self.assertTrue(ret) + + def test_set_online_matches_is_online_offline_domain(self): + """Verify worker is considered offline if connected_peers for domain is <0.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_online("worker1") + worker_metrics.update("worker1", "d", "connected_peers", -1) + + ret = worker_metrics.get("worker1").is_online("d") + self.assertFalse(ret) + + @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True) + def test_get_best_worker_returns_best(self, config_mock): + """Verify get_best_worker returns the worker with least connected clients for equally weighted workers.""" + test_config = mock.MagicMock(spec=config.Config) + test_config.workers = config.Workers.from_dict({}) + config_mock.return_value = test_config + + worker_metrics = WorkerMetricsCollection() + worker_metrics.update("1", "d", "connected_peers", 20) + worker_metrics.update("2", "d", "connected_peers", 19) + worker_metrics.set_online("1") + worker_metrics.set_online("2") + + (worker, diff, connected) = worker_metrics.get_best_worker("d") + self.assertEqual(worker, "2") + self.assertEqual(diff, -20) # 19-(1*(20+19)) + self.assertEqual(connected, 19) + + @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True) + def test_get_best_worker_weighted_returns_best(self, config_mock): + """Verify get_best_worker returns the worker with least client differential for weighted workers.""" + test_config = mock.MagicMock(spec=config.Config) + test_config.workers = config.Workers.from_dict( + {"1": {"weight": 84}, "2": {"weight": 42}} + ) + config_mock.return_value = test_config + + worker_metrics = WorkerMetricsCollection() + worker_metrics.update("1", "d", "connected_peers", 21) + worker_metrics.update("2", "d", "connected_peers", 19) + worker_metrics.set_online("1") + worker_metrics.set_online("2") + + (worker, _, _) = worker_metrics.get_best_worker("d") + config_mock.assert_called() + self.assertEqual(worker, "1") + + def test_get_best_worker_no_worker_online_returns_none(self): + """Verify get_best_worker returns None if there is no online worker.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.update("1", "d", "connected_peers", 20) + worker_metrics.update("2", "d", "connected_peers", 19) + worker_metrics.set_offline("1") + worker_metrics.set_offline("2") + + (worker, _, _) = worker_metrics.get_best_worker("d") + self.assertIsNone(worker) + + def test_get_best_worker_no_worker_registered_returns_none(self): + """Verify get_best_worker returns None if there is no online worker.""" + worker_metrics = WorkerMetricsCollection() + + (worker, _, _) = worker_metrics.get_best_worker("d") + self.assertIsNone(worker) + + +if __name__ == "__main__": + unittest.main() From 50a7ca686472a1f638f07900c761815ce0552139 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sun, 17 Dec 2023 21:27:15 +0000 Subject: [PATCH 5/8] Add more tests for worker/netlink.py --- wgkex/worker/netlink.py | 2 +- wgkex/worker/netlink_test.py | 67 +++++++++++++++++++++++++++++++++--- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index a880d64..a1b5411 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -242,7 +242,7 @@ def get_connected_peers_count(wg_interface: str) -> int: return count -def get_device_data(wg_interface: str) -> Tuple[Any, Any, Any]: +def get_device_data(wg_interface: str) -> Tuple[int, str, str]: """Returns the listening port, public key and local IP address. Arguments: diff --git a/wgkex/worker/netlink_test.py b/wgkex/worker/netlink_test.py index 68cd573..0874005 100644 --- a/wgkex/worker/netlink_test.py +++ b/wgkex/worker/netlink_test.py @@ -26,15 +26,31 @@ ) -def _get_wg_mock(key_name, stale_time): - pm = mock.Mock() - pm.get_attr.side_effect = [{"tv_sec": stale_time}, key_name.encode()] +def _get_peer_mock(public_key, last_handshake_time): + def peer_get_attr(attr: str): + if attr == "WGPEER_A_LAST_HANDSHAKE_TIME": + return {"tv_sec": last_handshake_time} + if attr == "WGPEER_A_PUBLIC_KEY": + return public_key.encode() + peer_mock = mock.Mock() - peer_mock.get_attr.side_effect = [[pm]] + peer_mock.get_attr.side_effect = peer_get_attr + return peer_mock + + +def _get_wg_mock(public_key, last_handshake_time): + peer_mock = _get_peer_mock(public_key, last_handshake_time) + + def msg_get_attr(attr: str): + if attr == "WGDEVICE_A_PEERS": + return [peer_mock] + + msg_mock = mock.Mock() + msg_mock.get_attr.side_effect = msg_get_attr wg_instance = WireGuard() wg_info_mock = wg_instance.__enter__.return_value wg_info_mock.set.return_value = {"WireGuard": "set"} - wg_info_mock.info.return_value = [peer_mock] + wg_info_mock.info.return_value = [msg_mock] return wg_info_mock @@ -188,6 +204,47 @@ def test_wg_flush_stale_peers_stale_success(self): "del", dst="fe80::281:16ff:fe49:395e/128", oif=mock.ANY ) + def test_get_connected_peers_count_success(self): + """Tests getting the correct number of connected peers for an interface.""" + peers = [] + for i in range(10): + peer_mock = _get_peer_mock( + "TEST_KEY", + int((datetime.now() - timedelta(minutes=i, seconds=5)).timestamp()), + ) + peers.append(peer_mock) + + def msg_get_attr(attr: str): + if attr == "WGDEVICE_A_PEERS": + return peers + + msg_mock = mock.Mock() + msg_mock.get_attr.side_effect = msg_get_attr + + wg_instance = WireGuard() + wg_info_mock = wg_instance.__enter__.return_value + wg_info_mock.info.return_value = [msg_mock] + + ret = netlink.get_connected_peers_count("wg-welt") + self.assertEqual(ret, 3) + + def test_get_device_data_success(self): + def msg_get_attr(attr: str): + if attr == "WGDEVICE_A_LISTEN_PORT": + return 51820 + if attr == "WGDEVICE_A_PUBLIC_KEY": + return "TEST_PUBLIC_KEY".encode("ascii") + + msg_mock = mock.Mock() + msg_mock.get_attr.side_effect = msg_get_attr + + wg_instance = WireGuard() + wg_info_mock = wg_instance.__enter__.return_value + wg_info_mock.info.return_value = [msg_mock] + + ret = netlink.get_device_data("wg-welt") + self.assertTupleEqual(ret, (51820, "TEST_PUBLIC_KEY", mock.ANY)) + if __name__ == "__main__": unittest.main() From be4206325364c907d276676230b3f5cdc58d8bac Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sun, 17 Dec 2023 21:27:15 +0000 Subject: [PATCH 6/8] Add more tests for worker/mqtt.py --- wgkex/worker/mqtt_test.py | 44 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index 803125a..1c6bf81 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -1,8 +1,13 @@ """Unit tests for mqtt.py""" +import socket import threading import unittest +from time import sleep + import mock +import paho.mqtt.client +from wgkex.common.mqtt import TOPIC_CONNECTED_PEERS from wgkex.worker import mqtt @@ -50,6 +55,44 @@ def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock): mqtt.connect(threading.Event()) + @mock.patch.object(mqtt, "get_config") + @mock.patch.object(mqtt, "get_connected_peers_count") + def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): + config_mock.return_value = _get_config_mock() + conn_peers_mock.return_value = 20 + mqtt_client = mock.MagicMock(spec=paho.mqtt.client.Client) + + ee = threading.Event() + thread = threading.Thread( + target=mqtt.publish_metrics_loop, + args=(ee, mqtt_client, "_ffmuc_domain.one"), + ) + thread.start() + + i = 0 + while i < 20 and not mqtt_client.publish.called: + i += 1 + sleep(0.1) + + conn_peers_mock.assert_called_with("wg-domain.one") + mqtt_client.publish.assert_called_with( + TOPIC_CONNECTED_PEERS.format( + domain="_ffmuc_domain.one", worker=socket.gethostname() + ), + 20, + retain=True, + ) + + ee.set() + + i = 0 + while i < 20 and thread.is_alive(): + i += 1 + sleep(0.1) + + self.assertFalse(thread.is_alive()) + + """ @mock.patch.object(msg_queue, "link_handler") @mock.patch.object(mqtt, "get_config") def test_on_message_wireguard_success(self, config_mock, link_mock): @@ -92,5 +135,6 @@ def test_on_message_wireguard_fails_no_domain(self, config_mock, link_mock): mqtt.on_message_wireguard(None, None, mqtt_msg) """ + if __name__ == "__main__": unittest.main() From 7aa996797204ad7933f1e0fb3e1b00d6367b71f7 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sun, 17 Dec 2023 21:27:15 +0000 Subject: [PATCH 7/8] Fix logical merge conflicts --- wgkex/broker/metrics_test.py | 2 +- wgkex/common/utils.py | 9 ++++++--- wgkex/worker/mqtt.py | 6 ++---- wgkex/worker/mqtt_test.py | 7 +++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/wgkex/broker/metrics_test.py b/wgkex/broker/metrics_test.py index 63b25bf..520e6a9 100644 --- a/wgkex/broker/metrics_test.py +++ b/wgkex/broker/metrics_test.py @@ -12,7 +12,7 @@ def setUpClass(cls) -> None: test_config = config.Config.from_dict( { "domains": [], - "domain_prefix": "", + "domain_prefixes": "", "workers": {}, "mqtt": {"broker_url": "", "username": "", "password": ""}, } diff --git a/wgkex/common/utils.py b/wgkex/common/utils.py index 45c7b7b..8fa201c 100644 --- a/wgkex/common/utils.py +++ b/wgkex/common/utils.py @@ -48,6 +48,9 @@ def is_valid_domain(domain: str) -> bool: Returns: True if the domain is valid, False otherwise. """ - return domain in config.get_config().domains and domain.startswith( - config.get_config().domain_prefix - ) + if not domain in config.get_config().domains: + return False + for prefix in config.get_config().domain_prefixes: + if domain.startswith(prefix): + return True + return False diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index ef45345..caf7011 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -238,13 +238,11 @@ def wg_interface_name(domain: str) -> Optional[str]: cleaned_domain = None for prefix in domain_prefixes: try: - cleaned_domain = domain.split(prefix[1]) + cleaned_domain = domain.split(prefix)[1] except IndexError: continue break if not cleaned_domain: - raise ValueError( - f"Could not find a match for {domain_prefixes} on {domain}" - ) + raise ValueError(f"Could not find a match for {domain_prefixes} on {domain}") # this will not work, if we have non-unique prefix stripped domains return f"wg-{cleaned_domain}" diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index 1c6bf81..b17d1d6 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -12,12 +12,12 @@ def _get_config_mock(domains=None, mqtt=None): - test_prefix = "_ffmuc_" + test_prefixes = ["_ffmuc_", "_TEST_PREFIX2_"] config_mock = mock.MagicMock() config_mock.domains = ( - domains if domains is not None else [f"{test_prefix}domain.one"] + domains if domains is not None else [f"{test_prefixes[0]}domain.one"] ) - config_mock.domain_prefix = test_prefix + config_mock.domain_prefixes = test_prefixes if mqtt: config_mock.mqtt = mqtt return config_mock @@ -54,7 +54,6 @@ def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock): with self.assertRaises(ValueError): mqtt.connect(threading.Event()) - @mock.patch.object(mqtt, "get_config") @mock.patch.object(mqtt, "get_connected_peers_count") def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): From bab86f7c11ed68a18848cecd1f00297f052a0af8 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sat, 6 Jan 2024 19:02:54 +0000 Subject: [PATCH 8/8] Make worker cleanup threads more robust, handle peers without handshake time --- README.md | 8 ++++---- wgkex/worker/app.py | 14 +++++++++----- wgkex/worker/app_test.py | 30 +++++++++++++++++++++++------- wgkex/worker/mqtt_test.py | 21 ++++++--------------- wgkex/worker/netlink.py | 9 ++++----- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 9f673f8..911aeb8 100644 --- a/README.md +++ b/README.md @@ -218,10 +218,10 @@ sudo ip link set vx-welt up ### MQTT topics -Publishing keys broker->worker: `wireguard/{domain}/{worker}` -Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers` -Publishing worker status: `wireguard-worker/{worker}/status` -Publishing worker data: `wireguard-worker/{worker}/{domain}/data` +- Publishing keys broker->worker: `wireguard/{domain}/{worker}` +- Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers` +- Publishing worker status: `wireguard-worker/{worker}/status` +- Publishing worker data: `wireguard-worker/{worker}/{domain}/data` ## Contact diff --git a/wgkex/worker/app.py b/wgkex/worker/app.py index 9a07d97..432955c 100644 --- a/wgkex/worker/app.py +++ b/wgkex/worker/app.py @@ -39,9 +39,14 @@ class InvalidDomain(Error): def flush_workers(domain: Text) -> None: """Calls peer flush every _CLEANUP_TIME interval.""" while True: - time.sleep(_CLEANUP_TIME) - logger.info(f"Running cleanup task for {domain}") - logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain)) + try: + time.sleep(_CLEANUP_TIME) + logger.info(f"Running cleanup task for {domain}") + logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain)) + except Exception as e: + # Don't crash the thread when an exception is encountered + logger.error(f"Exception during cleanup task for {domain}:") + logger.error(e) def clean_up_worker() -> None: @@ -100,8 +105,7 @@ def check_all_domains_unique(domains, prefixes): stripped_domain = domain.split(prefix)[1] if stripped_domain in unique_domains: logger.error( - "We have a non-unique domain here", - domain, + f"Domain {domain} is not unique after stripping the prefix" ) return False unique_domains.append(stripped_domain) diff --git a/wgkex/worker/app_test.py b/wgkex/worker/app_test.py index 0cf525f..04cc6fb 100644 --- a/wgkex/worker/app_test.py +++ b/wgkex/worker/app_test.py @@ -1,4 +1,6 @@ """Unit tests for app.py""" +import threading +from time import sleep import unittest import mock @@ -88,14 +90,28 @@ def test_main_fails_bad_domain(self, connect_mock, config_mock): app.main() connect_mock.assert_not_called() - @mock.patch("time.sleep", side_effect=InterruptedError) + @mock.patch.object(app, "_CLEANUP_TIME", 0) @mock.patch.object(app, "wg_flush_stale_peers") - def test_flush_workers(self, flush_mock, sleep_mock): - """Ensure we fail when domains are badly formatted.""" - flush_mock.return_value = "" - # Infinite loop in flush_workers has no exit value, so test will generate one, and test for that. - with self.assertRaises(InterruptedError): - app.flush_workers("test_domain") + def test_flush_workers_doesnt_throw(self, wg_flush_mock): + """Ensure the flush_workers thread doesn't throw and exit if it encounters an exception.""" + wg_flush_mock.side_effect = AttributeError( + "'NoneType' object has no attribute 'get'" + ) + + thread = threading.Thread( + target=app.flush_workers, args=("dummy_domain",), daemon=True + ) + thread.start() + + i = 0 + while i < 20 and not wg_flush_mock.called: + i += 1 + sleep(0.1) + + wg_flush_mock.assert_called() + # Assert that the thread hasn't crashed and is still running + self.assertTrue(thread.is_alive()) + # If Python would allow it without writing custom signalling, this would be the place to stop the thread again if __name__ == "__main__": diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index b17d1d6..8bd6672 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -91,29 +91,20 @@ def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): self.assertFalse(thread.is_alive()) - -""" @mock.patch.object(msg_queue, "link_handler") @mock.patch.object(mqtt, "get_config") - def test_on_message_wireguard_success(self, config_mock, link_mock): + def test_on_message_wireguard_success(self, config_mock): # Tests on_message for success. config_mock.return_value = _get_config_mock() - link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") mqtt_msg.topic = "wireguard/_ffmuc_domain1/gateway" mqtt_msg.payload = b"PUB_KEY" mqtt.on_message_wireguard(None, None, mqtt_msg) - link_mock.assert_has_calls( - [ - mock.call( - msg_queue.WireGuardClient( - public_key="PUB_KEY", domain="domain1", remove=False - ) - ) - ], - any_order=True, - ) + self.assertTrue(mqtt.q.qsize() > 0) + item = mqtt.q.get_nowait() + self.assertEqual(item, ("domain1", "PUB_KEY")) + - @mock.patch.object(msg_queue, "link_handler") +""" @mock.patch.object(msg_queue, "link_handler") @mock.patch.object(mqtt, "get_config") def test_on_message_wireguard_fails_no_domain(self, config_mock, link_mock): # Tests on_message for failure to parse domain. diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index a1b5411..366d430 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -72,13 +72,12 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]: stale_clients = [ stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain) ] - logger.debug("Found stale clients: %s", stale_clients) - logger.info("Searching for stale WireGuard clients.") + logger.debug("Found %s stale clients: %s", len(stale_clients), stale_clients) stale_wireguard_clients = [ WireGuardClient(public_key=stale_client, domain=domain, remove=True) for stale_client in stale_clients ] - logger.debug("Found stable WireGuard clients: %s", stale_wireguard_clients) + logger.debug("Found stale WireGuard clients: %s", stale_wireguard_clients) logger.info("Processing clients.") link_handled = [ link_handler(stale_client) for stale_client in stale_wireguard_clients @@ -205,8 +204,8 @@ def find_stale_wireguard_clients(wg_interface: str) -> List: ret = [ peer.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8") for peer in all_peers - if peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get("tv_sec", int()) - < three_hrs_in_secs + if (hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")) is not None + and hshk_time.get("tv_sec", int()) < three_hrs_in_secs ] return ret