Make worker cleanup threads more robust, handle peers without handshake time
DasSkelett committed Jan 6, 2024
1 parent 7a17a2c commit 0d53e29
Showing 4 changed files with 39 additions and 19 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -218,10 +218,10 @@ sudo ip link set vx-welt up

### MQTT topics

Publishing keys broker->worker: `wireguard/{domain}/{worker}`
Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers`
Publishing worker status: `wireguard-worker/{worker}/status`
Publishing worker data: `wireguard-worker/{worker}/{domain}/data`
- Publishing keys broker->worker: `wireguard/{domain}/{worker}`
- Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers`
- Publishing worker status: `wireguard-worker/{worker}/status`
- Publishing worker data: `wireguard-worker/{worker}/{domain}/data`

## Contact

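The topics above are naming patterns only. As an illustration (not part of this commit or the repository), a client could watch the worker status and metrics topics with paho-mqtt roughly as sketched below; the broker host is a placeholder and the `+` wildcards stand in for concrete domain and worker names.

```python
# Minimal sketch, assuming a reachable MQTT broker; the host and the "+"
# wildcard topics are placeholders, not values taken from this repository.
import paho.mqtt.client as mqtt

BROKER_HOST = "broker.example.org"  # placeholder


def on_message(client, userdata, message):
    # e.g. wireguard-metrics/{domain}/{worker}/connected_peers -> peer count payload
    print(f"{message.topic}: {message.payload.decode()}")


# paho-mqtt 1.x style; with paho-mqtt 2.x pass mqtt.CallbackAPIVersion.VERSION1
client = mqtt.Client()
client.on_message = on_message
client.connect(BROKER_HOST)
client.subscribe("wireguard-worker/+/status")
client.subscribe("wireguard-metrics/+/+/connected_peers")
client.loop_forever()
```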
13 changes: 9 additions & 4 deletions wgkex/worker/app.py
@@ -39,9 +39,14 @@ class InvalidDomain(Error):
def flush_workers(domain: Text) -> None:
"""Calls peer flush every _CLEANUP_TIME interval."""
while True:
time.sleep(_CLEANUP_TIME)
logger.info(f"Running cleanup task for {domain}")
logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain))
try:
time.sleep(_CLEANUP_TIME)
logger.info(f"Running cleanup task for {domain}")
logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain))
except Exception as e:
# Don't crash the thread when an exception is encountered
logger.error(f"Exception during cleanup task for {domain}:")
logger.error(e)


def clean_up_worker() -> None:
@@ -100,7 +105,7 @@ def check_all_domains_unique(domains, prefixes):
stripped_domain = domain.split(prefix)[1]
if stripped_domain in unique_domains:
logger.error(
"We have a non-unique domain here",
"Domain %s is not unique after stripping the prefix",
domain,
)
return False
28 changes: 22 additions & 6 deletions wgkex/worker/app_test.py
@@ -1,4 +1,6 @@
"""Unit tests for app.py"""
import threading
from time import sleep
import unittest
import mock

@@ -88,14 +90,28 @@ def test_main_fails_bad_domain(self, connect_mock, config_mock):
app.main()
connect_mock.assert_not_called()

@mock.patch("time.sleep", side_effect=InterruptedError)
@mock.patch.object(app, "_CLEANUP_TIME", 0)
@mock.patch.object(app, "wg_flush_stale_peers")
def test_flush_workers(self, flush_mock, sleep_mock):
def test_flush_workers_doesnt_throw(self, wg_flush_mock):
"""Ensure the flush_workers thread keeps running when wg_flush_stale_peers raises."""
flush_mock.return_value = ""
# Infinite loop in flush_workers has no exit value, so test will generate one, and test for that.
with self.assertRaises(InterruptedError):
app.flush_workers("test_domain")
wg_flush_mock.side_effect = AttributeError(
"'NoneType' object has no attribute 'get'"
)

thread = threading.Thread(
target=app.flush_workers, args=("dummy_domain",), daemon=True
)
thread.start()

i = 0
while i < 20 and not wg_flush_mock.called:
i += 1
sleep(0.1)

wg_flush_mock.assert_called()
# Assert that the thread hasn't crashed and is still running
self.assertTrue(thread.is_alive())
# Stopping the thread cleanly here would need custom signalling (e.g. an Event); see the sketch after this diff


if __name__ == "__main__":
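The comment at the end of the new test alludes to custom signalling for stopping the cleanup thread. A minimal sketch of what that could look like, assuming a hypothetical refactor in which flush_workers accepts a stop event (this is not how the code in this commit is written):

```python
# Hypothetical refactor (not part of this commit): flush_workers takes a stop
# event, so a test can shut the cleanup thread down cleanly.
import logging
import threading

logger = logging.getLogger(__name__)
_CLEANUP_TIME = 0.1  # stand-in value for this sketch


def flush_workers(domain: str, stop_event: threading.Event) -> None:
    """Runs the cleanup loop until stop_event is set."""
    while not stop_event.is_set():
        # Event.wait() sleeps like time.sleep(), but returns as soon as the event is set
        stop_event.wait(_CLEANUP_TIME)
        try:
            logger.info("Running cleanup task for %s", domain)
            # wg_flush_stale_peers(domain) would be called here
        except Exception as e:
            logger.error("Exception during cleanup task for %s: %s", domain, e)


# A test could then stop the thread deterministically instead of leaving the
# daemon thread running:
stop = threading.Event()
thread = threading.Thread(target=flush_workers, args=("dummy_domain", stop), daemon=True)
thread.start()
stop.set()
thread.join(timeout=1)
assert not thread.is_alive()
```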
9 changes: 4 additions & 5 deletions wgkex/worker/netlink.py
@@ -72,13 +72,12 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]:
stale_clients = [
stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain)
]
logger.debug("Found stale clients: %s", stale_clients)
logger.info("Searching for stale WireGuard clients.")
logger.debug("Found %s stale clients: %s", len(stale_clients), stale_clients)
stale_wireguard_clients = [
WireGuardClient(public_key=stale_client, domain=domain, remove=True)
for stale_client in stale_clients
]
logger.debug("Found stable WireGuard clients: %s", stale_wireguard_clients)
logger.debug("Found stale WireGuard clients: %s", stale_wireguard_clients)
logger.info("Processing clients.")
link_handled = [
link_handler(stale_client) for stale_client in stale_wireguard_clients
@@ -205,8 +204,8 @@ def find_stale_wireguard_clients(wg_interface: str) -> List:
ret = [
peer.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8")
for peer in all_peers
if peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get("tv_sec", int())
< three_hrs_in_secs
if (hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")) is not None
and hshk_time.get("tv_sec", int()) < three_hrs_in_secs
]
return ret

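For readers of this hunk: `three_hrs_in_secs` is defined above the shown context. One plausible definition (an assumption, not taken from this diff) is a cutoff Unix timestamp three hours in the past, so that `tv_sec < three_hrs_in_secs` means the peer's last handshake happened more than three hours ago:

```python
from datetime import datetime, timedelta

# Assumed definition (not shown in this diff): a cutoff timestamp three hours
# in the past; peers whose last handshake lies before it count as stale.
three_hrs_in_secs = int((datetime.now() - timedelta(hours=3)).timestamp())
```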
