Skip to content

Commit

Permalink
Make worker cleanup threads more robust, handle peers without handsha…
Browse files Browse the repository at this point in the history
…ke time
  • Loading branch information
DasSkelett committed Jan 7, 2024
1 parent 7aa9967 commit bab86f7
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 36 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ sudo ip link set vx-welt up

### MQTT topics

Publishing keys broker->worker: `wireguard/{domain}/{worker}`
Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers`
Publishing worker status: `wireguard-worker/{worker}/status`
Publishing worker data: `wireguard-worker/{worker}/{domain}/data`
- Publishing keys broker->worker: `wireguard/{domain}/{worker}`
- Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers`
- Publishing worker status: `wireguard-worker/{worker}/status`
- Publishing worker data: `wireguard-worker/{worker}/{domain}/data`

## Contact

Expand Down
14 changes: 9 additions & 5 deletions wgkex/worker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@ class InvalidDomain(Error):
def flush_workers(domain: Text) -> None:
"""Calls peer flush every _CLEANUP_TIME interval."""
while True:
time.sleep(_CLEANUP_TIME)
logger.info(f"Running cleanup task for {domain}")
logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain))
try:
time.sleep(_CLEANUP_TIME)
logger.info(f"Running cleanup task for {domain}")
logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain))
except Exception as e:
# Don't crash the thread when an exception is encountered
logger.error(f"Exception during cleanup task for {domain}:")
logger.error(e)


def clean_up_worker() -> None:
Expand Down Expand Up @@ -100,8 +105,7 @@ def check_all_domains_unique(domains, prefixes):
stripped_domain = domain.split(prefix)[1]
if stripped_domain in unique_domains:
logger.error(
"We have a non-unique domain here",
domain,
f"Domain {domain} is not unique after stripping the prefix"
)
return False
unique_domains.append(stripped_domain)
Expand Down
30 changes: 23 additions & 7 deletions wgkex/worker/app_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Unit tests for app.py"""
import threading
from time import sleep
import unittest
import mock

Expand Down Expand Up @@ -88,14 +90,28 @@ def test_main_fails_bad_domain(self, connect_mock, config_mock):
app.main()
connect_mock.assert_not_called()

@mock.patch("time.sleep", side_effect=InterruptedError)
@mock.patch.object(app, "_CLEANUP_TIME", 0)
@mock.patch.object(app, "wg_flush_stale_peers")
def test_flush_workers(self, flush_mock, sleep_mock):
"""Ensure we fail when domains are badly formatted."""
flush_mock.return_value = ""
# Infinite loop in flush_workers has no exit value, so test will generate one, and test for that.
with self.assertRaises(InterruptedError):
app.flush_workers("test_domain")
def test_flush_workers_doesnt_throw(self, wg_flush_mock):
"""Ensure the flush_workers thread doesn't throw and exit if it encounters an exception."""
wg_flush_mock.side_effect = AttributeError(
"'NoneType' object has no attribute 'get'"
)

thread = threading.Thread(
target=app.flush_workers, args=("dummy_domain",), daemon=True
)
thread.start()

i = 0
while i < 20 and not wg_flush_mock.called:
i += 1
sleep(0.1)

wg_flush_mock.assert_called()
# Assert that the thread hasn't crashed and is still running
self.assertTrue(thread.is_alive())
# If Python would allow it without writing custom signalling, this would be the place to stop the thread again


if __name__ == "__main__":
Expand Down
21 changes: 6 additions & 15 deletions wgkex/worker/mqtt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,29 +91,20 @@ def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock):

self.assertFalse(thread.is_alive())


""" @mock.patch.object(msg_queue, "link_handler")
@mock.patch.object(mqtt, "get_config")
def test_on_message_wireguard_success(self, config_mock, link_mock):
def test_on_message_wireguard_success(self, config_mock):
# Tests on_message for success.
config_mock.return_value = _get_config_mock()
link_mock.return_value = dict(WireGuard="result")
mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage")
mqtt_msg.topic = "wireguard/_ffmuc_domain1/gateway"
mqtt_msg.payload = b"PUB_KEY"
mqtt.on_message_wireguard(None, None, mqtt_msg)
link_mock.assert_has_calls(
[
mock.call(
msg_queue.WireGuardClient(
public_key="PUB_KEY", domain="domain1", remove=False
)
)
],
any_order=True,
)
self.assertTrue(mqtt.q.qsize() > 0)
item = mqtt.q.get_nowait()
self.assertEqual(item, ("domain1", "PUB_KEY"))


@mock.patch.object(msg_queue, "link_handler")
""" @mock.patch.object(msg_queue, "link_handler")
@mock.patch.object(mqtt, "get_config")
def test_on_message_wireguard_fails_no_domain(self, config_mock, link_mock):
# Tests on_message for failure to parse domain.
Expand Down
9 changes: 4 additions & 5 deletions wgkex/worker/netlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]:
stale_clients = [
stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain)
]
logger.debug("Found stale clients: %s", stale_clients)
logger.info("Searching for stale WireGuard clients.")
logger.debug("Found %s stale clients: %s", len(stale_clients), stale_clients)
stale_wireguard_clients = [
WireGuardClient(public_key=stale_client, domain=domain, remove=True)
for stale_client in stale_clients
]
logger.debug("Found stable WireGuard clients: %s", stale_wireguard_clients)
logger.debug("Found stale WireGuard clients: %s", stale_wireguard_clients)
logger.info("Processing clients.")
link_handled = [
link_handler(stale_client) for stale_client in stale_wireguard_clients
Expand Down Expand Up @@ -205,8 +204,8 @@ def find_stale_wireguard_clients(wg_interface: str) -> List:
ret = [
peer.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8")
for peer in all_peers
if peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get("tv_sec", int())
< three_hrs_in_secs
if (hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")) is not None
and hshk_time.get("tv_sec", int()) < three_hrs_in_secs
]
return ret

Expand Down

0 comments on commit bab86f7

Please sign in to comment.