diff --git a/wgkex.yaml.example b/wgkex.yaml.example index 30340fe..afef1e6 100644 --- a/wgkex.yaml.example +++ b/wgkex.yaml.example @@ -1,4 +1,4 @@ -# [broker] The domains that should be accepted by clients and for which matching WireGuard interfaces exist +# [broker, worker] The domains that should be accepted by clients and for which matching WireGuard interfaces exist domains: - ffmuc_muc_cty - ffmuc_muc_nord diff --git a/wgkex/broker/app.py b/wgkex/broker/app.py index 1d753ff..e8122cc 100644 --- a/wgkex/broker/app.py +++ b/wgkex/broker/app.py @@ -131,8 +131,15 @@ def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]: } }, 400 + # Update number of peers locally to interpolate data between MQTT updates from the worker + # TODO fix data race + current_peers_domain = ( + worker_metrics.get(best_worker) + .get_domain_metrics(domain) + .get(CONNECTED_PEERS_METRIC, 0) + ) worker_metrics.update( - best_worker, domain, CONNECTED_PEERS_METRIC, current_peers + 1 + best_worker, domain, CONNECTED_PEERS_METRIC, current_peers_domain + 1 ) logger.debug( f"Chose worker {best_worker} with {current_peers} connected clients ({diff})" @@ -200,10 +207,10 @@ def handle_mqtt_message_status( _, worker, _ = message.topic.split("/", 2) status = int(message.payload) - if status < 1: + if status < 1 and worker_metrics.get(worker).is_online(): logger.warning(f"Marking worker as offline: {worker}") worker_metrics.set_offline(worker) - else: + elif status >= 1 and not worker_metrics.get(worker).is_online(): logger.warning(f"Marking worker as online: {worker}") worker_metrics.set_online(worker) diff --git a/wgkex/broker/metrics.py b/wgkex/broker/metrics.py index a2e2893..73a27c0 100644 --- a/wgkex/broker/metrics.py +++ b/wgkex/broker/metrics.py @@ -34,10 +34,23 @@ def set_metric(self, domain: str, metric: str, value: Any) -> None: else: self.domain_data[domain] = {metric: value} + def get_peer_count(self) -> int: + """Returns the sum of connected peers on this worker over all domains""" + total = 0 + for data in self.domain_data.values(): + total += max( + data.get(CONNECTED_PEERS_METRIC, 0), + 0, + ) + + return total + @dataclasses.dataclass class WorkerMetricsCollection: - """A container for all worker metrics""" + """A container for all worker metrics + # TODO make threadsafe / fix data races + """ # worker -> WorkerMetrics data: Dict[str, WorkerMetrics] = dataclasses.field(default_factory=dict) @@ -68,7 +81,8 @@ def set_offline(self, worker: str) -> None: if worker in self.data: self.data[worker].online = False - def get_total_peers(self) -> int: + def get_total_peer_count(self) -> int: + """Returns the sum of connected peers over all workers and domains""" total = 0 for worker in self.data: worker_data = self.data.get(worker) @@ -96,22 +110,23 @@ def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]: # Map metrics to a list of (target diff, peer count, worker) tuples for online workers peers_worker_tuples = [] - total_peers = self.get_total_peers() + total_peers = self.get_total_peer_count() worker_cfg = config.get_config().workers for wm in self.data.values(): if not wm.is_online(domain): continue - peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC) + peers = wm.get_peer_count() rel_weight = worker_cfg.relative_worker_weight(wm.worker) target = rel_weight * total_peers diff = peers - target logger.debug( - f"Worker {wm.worker}: rel weight {rel_weight}, target {target} (total {total_peers}), diff {diff}" + f"Worker candidate {wm.worker}: current {peers}, target {target} (total {total_peers}, rel weight {rel_weight}), diff {diff}" ) peers_worker_tuples.append((diff, peers, wm.worker)) + # Sort by diff (ascending), workers with most peers missing to target are sorted first peers_worker_tuples = sorted(peers_worker_tuples, key=itemgetter(0)) if len(peers_worker_tuples) > 0: diff --git a/wgkex/broker/metrics_test.py b/wgkex/broker/metrics_test.py index 520e6a9..97fc138 100644 --- a/wgkex/broker/metrics_test.py +++ b/wgkex/broker/metrics_test.py @@ -83,6 +83,26 @@ def test_get_best_worker_returns_best(self, config_mock): self.assertEqual(diff, -20) # 19-(1*(20+19)) self.assertEqual(connected, 19) + @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True) + def test_get_best_worker_returns_best_imbalanced_domains(self, config_mock): + """Verify get_best_worker returns the worker with overall least connected clients even if it has more clients on this domain.""" + test_config = mock.MagicMock(spec=config.Config) + test_config.workers = config.Workers.from_dict({}) + config_mock.return_value = test_config + + worker_metrics = WorkerMetricsCollection() + worker_metrics.update("1", "domain1", "connected_peers", 25) + worker_metrics.update("1", "domain2", "connected_peers", 5) + worker_metrics.update("2", "domain1", "connected_peers", 20) + worker_metrics.update("2", "domain2", "connected_peers", 20) + worker_metrics.set_online("1") + worker_metrics.set_online("2") + + (worker, diff, connected) = worker_metrics.get_best_worker("domain1") + self.assertEqual(worker, "1") + self.assertEqual(diff, -40) # 30-(1*(25+5+20+20)) + self.assertEqual(connected, 30) + @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True) def test_get_best_worker_weighted_returns_best(self, config_mock): """Verify get_best_worker returns the worker with least client differential for weighted workers.""" diff --git a/wgkex/worker/app_test.py b/wgkex/worker/app_test.py index 04cc6fb..efe774d 100644 --- a/wgkex/worker/app_test.py +++ b/wgkex/worker/app_test.py @@ -90,7 +90,7 @@ def test_main_fails_bad_domain(self, connect_mock, config_mock): app.main() connect_mock.assert_not_called() - @mock.patch.object(app, "_CLEANUP_TIME", 0) + @mock.patch.object(app, "_CLEANUP_TIME", 1) @mock.patch.object(app, "wg_flush_stale_peers") def test_flush_workers_doesnt_throw(self, wg_flush_mock): """Ensure the flush_workers thread doesn't throw and exit if it encounters an exception.""" diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index caf7011..d5941cd 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -127,6 +127,7 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None: logger.info(f"Subscribing to topic {topic}") client.subscribe(topic) + for domain in domains: # Publish worker data (WG pubkeys, ports, local addresses) iface = wg_interface_name(domain) if iface: diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index 8bd6672..3ea186e 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -54,6 +54,42 @@ def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock): with self.assertRaises(ValueError): mqtt.connect(threading.Event()) + @mock.patch.object(mqtt.mqtt, "Client") + @mock.patch.object(mqtt, "get_config") + @mock.patch.object(mqtt, "get_device_data") + def test_on_connect_subscribes( + self, get_device_data_mock, config_mock, mqtt_client_mock + ): + """Test that the on_connect callback correctly subscribes to all domains and pushes device data""" + config_mqtt_mock = mock.MagicMock() + config_mqtt_mock.broker_url = "some_url" + config_mqtt_mock.broker_port = 1833 + config_mqtt_mock.keepalive = False + config = _get_config_mock(mqtt=config_mqtt_mock) + config.external_name = None + config_mock.return_value = config + get_device_data_mock.return_value = (51820, "456asdf=", "fe80::1") + + hostname = socket.gethostname() + + mqtt.on_connect(mqtt.mqtt.Client(), None, None, 0) + + mqtt_client_mock.assert_has_calls( + [ + mock.call().subscribe("wireguard/_ffmuc_domain.one/+"), + mock.call().publish( + f"wireguard-worker/{hostname}/_ffmuc_domain.one/data", + '{"ExternalAddress": "%s", "Port": 51820, "PublicKey": "456asdf=", "LinkAddress": "fe80::1"}' + % hostname, + qos=1, + retain=True, + ), + mock.call().publish( + f"wireguard-worker/{hostname}/status", 1, qos=1, retain=True + ), + ] + ) + @mock.patch.object(mqtt, "get_config") @mock.patch.object(mqtt, "get_connected_peers_count") def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index 366d430..bb413f1 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -231,11 +231,10 @@ def get_connected_peers_count(wg_interface: str) -> int: if peers: for peer in peers: if ( - peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get( - "tv_sec", int() - ) - > three_mins_ago_in_secs - ): + hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME") + ) is not None and hshk_time.get( + "tv_sec", int() + ) > three_mins_ago_in_secs: count += 1 return count @@ -251,7 +250,7 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]: # The listening port, public key, and local IP address of the WireGuard interface. """ logger.info("Reading data from interface %s.", wg_interface) - with pyroute2.WireGuard() as wg, pyroute2.NDB() as ndb: + with pyroute2.WireGuard() as wg, pyroute2.IPRoute() as ipr: msgs = wg.info(wg_interface) logger.debug("Got infos for interface data: %s.", msgs) if len(msgs) > 1: @@ -262,7 +261,11 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]: port = int(info.get_attr("WGDEVICE_A_LISTEN_PORT")) public_key = info.get_attr("WGDEVICE_A_PUBLIC_KEY").decode("ascii") - link_address = ndb.interfaces[wg_interface].ipaddr[0].get("address") + + # Get link address using IPRoute + ipr_link = ipr.link_lookup(ifname=wg_interface)[0] + msgs = ipr.get_addr(index=ipr_link) + link_address = msgs[0].get_attr("IFA_ADDRESS") logger.debug( "Interface data: port '%s', public key '%s', link address '%s",