Skip to content

Commit

Permalink
Some fixes for the loadbalancing changes
Browse files Browse the repository at this point in the history
- [worker] Use pyroute2.IPRoute instead of .NDB to get wg link address, as NDB() takes 20 seconds to instantiate
- [worker] Fix get_connected_peers_count() for peers without a handshake time
- [broker] Use total_peers count to correctly calculate diff to expected peers
- [broker] Don't update worker status on MQTT messages if it hasn't actually changed
  • Loading branch information
DasSkelett committed Jan 9, 2024
1 parent bab86f7 commit 4f5dc66
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 14 deletions.
13 changes: 10 additions & 3 deletions wgkex/broker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,15 @@ def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]:
}
}, 400

# Update number of peers locally to interpolate data between MQTT updates from the worker
# TODO fix data race
current_peers_domain = (
worker_metrics.get(best_worker)
.get_domain_metrics(domain)
.get(CONNECTED_PEERS_METRIC, 0)
)
worker_metrics.update(
best_worker, domain, CONNECTED_PEERS_METRIC, current_peers + 1
best_worker, domain, CONNECTED_PEERS_METRIC, current_peers_domain + 1
)
logger.debug(
f"Chose worker {best_worker} with {current_peers} connected clients ({diff})"
Expand Down Expand Up @@ -200,10 +207,10 @@ def handle_mqtt_message_status(
_, worker, _ = message.topic.split("/", 2)

status = int(message.payload)
if status < 1:
if status < 1 and worker_metrics.get(worker).is_online():
logger.warning(f"Marking worker as offline: {worker}")
worker_metrics.set_offline(worker)
else:
elif status >= 1 and not worker_metrics.get(worker).is_online():
logger.warning(f"Marking worker as online: {worker}")
worker_metrics.set_online(worker)

Expand Down
10 changes: 6 additions & 4 deletions wgkex/broker/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def set_metric(self, domain: str, metric: str, value: Any) -> None:

@dataclasses.dataclass
class WorkerMetricsCollection:
"""A container for all worker metrics"""
"""A container for all worker metrics
# TODO make threadsafe / fix data races
"""

# worker -> WorkerMetrics
data: Dict[str, WorkerMetrics] = dataclasses.field(default_factory=dict)
Expand Down Expand Up @@ -103,15 +105,15 @@ def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]:
if not wm.is_online(domain):
continue

peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC)
rel_weight = worker_cfg.relative_worker_weight(wm.worker)
target = rel_weight * total_peers
diff = peers - target
diff = total_peers - target
logger.debug(
f"Worker {wm.worker}: rel weight {rel_weight}, target {target} (total {total_peers}), diff {diff}"
)
peers_worker_tuples.append((diff, peers, wm.worker))
peers_worker_tuples.append((diff, total_peers, wm.worker))

# Sort by diff (ascending), workers with most peers missing to target are sorted first
peers_worker_tuples = sorted(peers_worker_tuples, key=itemgetter(0))

if len(peers_worker_tuples) > 0:
Expand Down
1 change: 1 addition & 0 deletions wgkex/worker/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None:
logger.info(f"Subscribing to topic {topic}")
client.subscribe(topic)

for domain in domains:
# Publish worker data (WG pubkeys, ports, local addresses)
iface = wg_interface_name(domain)
if iface:
Expand Down
17 changes: 10 additions & 7 deletions wgkex/worker/netlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,10 @@ def get_connected_peers_count(wg_interface: str) -> int:
if peers:
for peer in peers:
if (
peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get(
"tv_sec", int()
)
> three_mins_ago_in_secs
):
hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")
) is not None and hshk_time.get(
"tv_sec", int()
) > three_mins_ago_in_secs:
count += 1

return count
Expand All @@ -251,7 +250,7 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]:
# The listening port, public key, and local IP address of the WireGuard interface.
"""
logger.info("Reading data from interface %s.", wg_interface)
with pyroute2.WireGuard() as wg, pyroute2.NDB() as ndb:
with pyroute2.WireGuard() as wg, pyroute2.IPRoute() as ipr:
msgs = wg.info(wg_interface)
logger.debug("Got infos for interface data: %s.", msgs)
if len(msgs) > 1:
Expand All @@ -262,7 +261,11 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]:

port = int(info.get_attr("WGDEVICE_A_LISTEN_PORT"))
public_key = info.get_attr("WGDEVICE_A_PUBLIC_KEY").decode("ascii")
link_address = ndb.interfaces[wg_interface].ipaddr[0].get("address")

# Get link address using IPRoute
ipr_link = ipr.link_lookup(ifname=wg_interface)[0]
msgs = ipr.get_addr(index=ipr_link)
link_address = msgs[0].get_attr("IFA_ADDRESS")

logger.debug(
"Interface data: port '%s', public key '%s', link address '%s",
Expand Down

0 comments on commit 4f5dc66

Please sign in to comment.