Skip to content

Commit

Permalink
Add queues for netlink messages
Browse files Browse the repository at this point in the history
This resolves #103
  • Loading branch information
awlx committed Sep 18, 2023
1 parent 4a9436d commit 250fb3f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 11 deletions.
11 changes: 11 additions & 0 deletions wgkex/worker/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ py_library(
],
)


py_test(
name = "netlink_test",
srcs = ["netlink_test.py"],
Expand Down Expand Up @@ -54,6 +55,7 @@ py_binary(
srcs = ["app.py"],
deps = [
":mqtt",
":msg_queue",
"//wgkex/config:config",
"//wgkex/common:logger",
],
Expand All @@ -67,3 +69,12 @@ py_test(
requirement("mock"),
],
)

py_library(
name = "msg_queue",
srcs = ["msg_queue.py"],
visibility = ["//visibility:public"],
deps = [
"//wgkex/common:logger",
],
)
3 changes: 2 additions & 1 deletion wgkex/worker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import wgkex.config.config as config
from wgkex.worker import mqtt
from wgkex.worker.msg_queue import watch_queue
from wgkex.worker.netlink import wg_flush_stale_peers
import threading
import time
from wgkex.common import logger
from typing import List, Text
Expand Down Expand Up @@ -59,6 +59,7 @@ def main():
domains = config.load_config().get("domains")
if not domains:
raise DomainsNotInConfig("Could not locate domains in configuration.")
watch_queue()
clean_up_worker(domains)
mqtt.connect()

Expand Down
16 changes: 6 additions & 10 deletions wgkex/worker/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from wgkex.config.config import load_config
import socket
import re
from wgkex.worker.netlink import link_handler
from wgkex.worker.netlink import WireGuardClient
from typing import Optional, Dict, List, Any, Union
from typing import Optional, Dict, Any, Union
from wgkex.common import logger
import queue


q = queue.Queue()

def fetch_from_config(var: str) -> Optional[Union[Dict[str, str], str]]:
"""Fetches values from configuration file.
Expand Down Expand Up @@ -93,13 +94,8 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) ->
)
domain = domain.group(1)
logger.debug("Found domain %s", domain)
client = WireGuardClient(
public_key=str(message.payload.decode("utf-8")),
domain=domain,
remove=False,
)

logger.info(
f"Received create message for key {client.public_key} on domain {domain} with lladdr {client.lladdr}"
)
# TODO(ruairi): Verify return type here.
logger.debug(link_handler(client))
q.put(domain, message.payload.decode("utf-8"))
26 changes: 26 additions & 0 deletions wgkex/worker/msg_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
import queue
import threading
from wgkex.common import logger
from wgkex.worker.netlink import link_handler
from wgkex.worker.netlink import WireGuardClient

q = queue.Queue()

def watch_queue() -> None:
"""Watches the queue for new messages."""
threading.Thread(target=worker, daemon=True).start()
while q.empty() != True:
pick_from_queue()

def pick_from_queue() -> None:
"""Picks a message from the queue and processes it."""
domain, message = q.get()
logger.debug("Processing queue item %s for domain %s", message, domain)
client = WireGuardClient(
public_key=str(message.payload.decode("utf-8")),
domain=domain,
remove=False,
)
logger.debug(link_handler(client))
q.task_done()

0 comments on commit 250fb3f

Please sign in to comment.