diff --git a/wgkex/worker/BUILD b/wgkex/worker/BUILD index 22e2424..ec4ffaa 100644 --- a/wgkex/worker/BUILD +++ b/wgkex/worker/BUILD @@ -16,6 +16,7 @@ py_library( ], ) + py_test( name = "netlink_test", srcs = ["netlink_test.py"], @@ -54,6 +55,7 @@ py_binary( srcs = ["app.py"], deps = [ ":mqtt", + ":msg_queue", "//wgkex/config:config", "//wgkex/common:logger", ], @@ -67,3 +69,12 @@ py_test( requirement("mock"), ], ) + +py_library( + name = "msg_queue", + srcs = ["msg_queue.py"], + visibility = ["//visibility:public"], + deps = [ + "//wgkex/common:logger", + ], +) \ No newline at end of file diff --git a/wgkex/worker/app.py b/wgkex/worker/app.py index e99575e..a917ed5 100644 --- a/wgkex/worker/app.py +++ b/wgkex/worker/app.py @@ -2,8 +2,8 @@ import wgkex.config.config as config from wgkex.worker import mqtt +from wgkex.worker.msg_queue import watch_queue from wgkex.worker.netlink import wg_flush_stale_peers -import threading import time from wgkex.common import logger from typing import List, Text @@ -59,6 +59,7 @@ def main(): domains = config.load_config().get("domains") if not domains: raise DomainsNotInConfig("Could not locate domains in configuration.") + watch_queue() clean_up_worker(domains) mqtt.connect() diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index dfa742a..21e749e 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -7,12 +7,13 @@ from wgkex.config.config import load_config import socket import re -from wgkex.worker.netlink import link_handler -from wgkex.worker.netlink import WireGuardClient -from typing import Optional, Dict, List, Any, Union +from typing import Optional, Dict, Any, Union from wgkex.common import logger +import queue +q = queue.Queue() + def fetch_from_config(var: str) -> Optional[Union[Dict[str, str], str]]: """Fetches values from configuration file. @@ -93,13 +94,8 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> ) domain = domain.group(1) logger.debug("Found domain %s", domain) - client = WireGuardClient( - public_key=str(message.payload.decode("utf-8")), - domain=domain, - remove=False, - ) + logger.info( f"Received create message for key {client.public_key} on domain {domain} with lladdr {client.lladdr}" ) - # TODO(ruairi): Verify return type here. - logger.debug(link_handler(client)) + q.put(domain, message.payload.decode("utf-8")) diff --git a/wgkex/worker/msg_queue.py b/wgkex/worker/msg_queue.py new file mode 100644 index 0000000..b164b7a --- /dev/null +++ b/wgkex/worker/msg_queue.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +import queue +import threading +from wgkex.common import logger +from wgkex.worker.netlink import link_handler +from wgkex.worker.netlink import WireGuardClient + +q = queue.Queue() + +def watch_queue() -> None: + """Watches the queue for new messages.""" + threading.Thread(target=worker, daemon=True).start() + while q.empty() != True: + pick_from_queue() + +def pick_from_queue() -> None: + """Picks a message from the queue and processes it.""" + domain, message = q.get() + logger.debug("Processing queue item %s for domain %s", message, domain) + client = WireGuardClient( + public_key=str(message.payload.decode("utf-8")), + domain=domain, + remove=False, + ) + logger.debug(link_handler(client)) + q.task_done() \ No newline at end of file