Skip to content

Commit

Permalink
use proper errors
Browse files Browse the repository at this point in the history
  • Loading branch information
elmarx committed Aug 29, 2023
1 parent a386e1a commit bb0aab6
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 97 deletions.
60 changes: 34 additions & 26 deletions src/dns/node_repository.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::net::ToSocketAddrs;
use std::str::from_utf8;

use crate::error;
use crate::error::NodeRepository::{
InvalidNameserver, MissingPubkeyRecord, UnresolvableSocketAddress,
};
use async_trait::async_trait;
use futures::future::join_all;
use rsdns::clients::tokio::Client;
Expand All @@ -9,7 +13,6 @@ use rsdns::constants::Class;
use rsdns::records::data::{Txt, A};
use rsdns::Error;

use crate::error::WgMesh as WgMeshError;
use crate::model::Peer;
use crate::traits::NodeRepository;

Expand All @@ -18,53 +21,58 @@ pub struct DnsNodeRepository {
}

impl DnsNodeRepository {
pub fn from_address(address: &str) -> Self {
pub fn from_address(address: &str) -> Result<Self, error::NodeRepository> {
let nameserver = address
.to_socket_addrs()
.unwrap()
.map_err(|e| UnresolvableSocketAddress(e, address.to_string()))?
.next()
.ok_or("could not get socket-address for dns.quad9.net:53")
.unwrap();
.ok_or_else(|| InvalidNameserver(address.to_string()))?;

let config = ClientConfig::with_nameserver(nameserver);

DnsNodeRepository { config }
Ok(DnsNodeRepository { config })
}
}

#[async_trait(?Send)]
impl NodeRepository for DnsNodeRepository {
async fn list_mesh_nodes(&self, mesh_record: &str) -> Vec<String> {
let mut client = Client::new(self.config.clone()).await.unwrap();
async fn list_mesh_nodes(
&self,
mesh_record: &str,
) -> Result<Vec<String>, error::NodeRepository> {
let mut client = Client::new(self.config.clone()).await?;

let response = client
.query_rrset::<Txt>(mesh_record, Class::In)
.await
.unwrap();
let response = client.query_rrset::<Txt>(mesh_record, Class::In).await?;

response
Ok(response
.rdata
.iter()
.map(|txt| from_utf8(&txt.text).unwrap())
.map(|txt| from_utf8(&txt.text).expect("non-UTF-8 TXT record in mesh-record"))
.map(ToString::to_string)
.collect()
.collect())
}

async fn fetch_peer(&self, node_addr: &str) -> Result<Peer, WgMeshError> {
let mut client = Client::new(self.config.clone()).await.unwrap();
async fn fetch_peer(&self, node_addr: &str) -> Result<Peer, error::NodeRepository> {
let mut client = Client::new(self.config.clone()).await?;

let qname = format!("_wireguard.{node_addr}");

let has_public_ipv4_address = match client.query_rrset::<A>(node_addr, Class::In).await {
Ok(a_query) => Ok(a_query.rdata.iter().any(|a| !a.address.is_private())),
Err(Error::NoAnswer) => Ok(false),
Err(e) => Err(e),
}
.map_err(WgMeshError::Rsdns)?;

let pubkey_query = client.query_rrset::<Txt>(&qname, Class::In).await.unwrap();
let pubkey = from_utf8(&pubkey_query.rdata.first().unwrap().text).unwrap();
let allowed_ips_query = client.query_rrset::<A>(&qname, Class::In).await.unwrap();
}?;

let pubkey_query = client.query_rrset::<Txt>(&qname, Class::In).await?;
let pubkey = from_utf8(
&pubkey_query
.rdata
.first()
.ok_or_else(|| MissingPubkeyRecord(qname.clone()))?
.text,
)
.expect("non-UTF-8 TXT record for peer pubkey");
let allowed_ips_query = client.query_rrset::<A>(&qname, Class::In).await?;
let allowed_ips: Vec<_> = allowed_ips_query
.rdata
.iter()
Expand All @@ -87,17 +95,17 @@ impl NodeRepository for DnsNodeRepository {
})
}

async fn fetch_all_peers(&self, mesh_record: &str) -> Vec<Peer> {
async fn fetch_all_peers(&self, mesh_record: &str) -> Result<Vec<Peer>, error::NodeRepository> {
let all_peers: Result<Vec<_>, _> = join_all(
self.list_mesh_nodes(mesh_record)
.await
.await?
.iter()
.map(|p| self.fetch_peer(p)),
)
.await
.into_iter()
.collect();

all_peers.unwrap()
all_peers
}
}
64 changes: 50 additions & 14 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,72 @@ use wireguard_control::InvalidKey;

#[derive(Error, Debug)]
pub enum WgMesh {
#[error(transparent)]
Mesh(#[from] Mesh),

#[error(transparent)]
NodeRepository(#[from] NodeRepository),

#[error(transparent)]
Routing(#[from] Routing),

#[error(transparent)]
Wireguard(#[from] Wireguard),
}

#[derive(Error, Debug)]
pub enum Mesh {
#[error("This peer is not part of the mesh (public_key {0} not found)")]
PeerNotPartOfMesh(String),
}

#[error("Invalid interface name: {0}")]
InvalidInterfaceName(String),
#[derive(Error, Debug)]
pub enum NodeRepository {
#[error(transparent)]
Rsdns(#[from] RsdnsError),

#[error("Cant turn {1} into a socket address: {0}")]
UnresolvableSocketAddress(std::io::Error, String),

#[error("Can't resolve nameserver: {0}")]
InvalidNameserver(String),

#[error("Missing TXT-Record for {0} (expecting record to hold pubkey)")]
MissingPubkeyRecord(String),
}

#[derive(Error, Debug)]
pub enum Routing {
#[error(transparent)]
NoSuchDevice(std::io::Error),
NetlinkError(#[from] rtnetlink::Error),

#[error("Public key missing")]
NoPubkey,
#[error("Netlink interface {0} not found")]
NoSuchInterface(String),
}

#[error("Invalid public key, could not decode: {0}, given key: {1}")]
InvalidPublicKey(InvalidKey, String),
#[derive(Error, Debug)]
pub enum Wireguard {
#[error("Invalid interface name: {0}")]
InvalidInterfaceName(String),

#[error("No answer from DNS server for {0}")]
NoResolveResponse(String),

#[error("Cant turn {1}:{2} into a socket address: {0}")]
UnresolvableSocketAddress(std::io::Error, String, u16),

#[error("Invalid IP address: {0}")]
InvalidIpAddress(String),

#[error("No answer from DNS server for {0}")]
NoResolveResponse(String),
#[error("Public key missing")]
NoPubkey,

#[error(transparent)]
NoSuchDevice(std::io::Error),

#[error("Failed to apply wireguard config: {0}")]
FailedToApplyConfig(std::io::Error),

#[error(transparent)]
NetlinkError(rtnetlink::Error),

#[error(transparent)]
Rsdns(RsdnsError),
#[error("Invalid public key, could not decode: {0}, given key: {1}")]
InvalidPublicKey(InvalidKey, String),
}
8 changes: 4 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.nth(1)
.expect("please pass record name with the peer list");

let (connection, handle, _) = new_connection().unwrap();
let (connection, handle, _) = new_connection()?;
tokio::spawn(connection);

// TODO: parse address(es) from resolve.conf
let peer_repository = DnsNodeRepository::from_address("dns.quad9.net:53");
let peer_repository = DnsNodeRepository::from_address("dns.quad9.net:53")?;

// TODO: loop/wait for device to be available
let wireguard_device = WireguardImpl::new(&interface_name);
let wireguard_device = WireguardImpl::new(&interface_name)?;
let routing_service = RoutingServiceImpl::new(handle, &interface_name);

let wg_mesh = WgMesh::new(peer_repository, routing_service, wireguard_device);
// TODO: loop/re-execute
wg_mesh.execute(&mesh_record).await;
wg_mesh.execute(&mesh_record).await?;

Ok(())
}
12 changes: 7 additions & 5 deletions src/mesh/filter_peers.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::error::WgMesh as WgMeshError;
use crate::error;
use crate::model::Peer;

/// given a list of all nodes of the mesh, filter out all nodes we need to actually peer with
pub fn filter_peers(interface_pubkey: &str, mut peers: Vec<Peer>) -> Vec<Peer> {
pub fn filter_peers(
interface_pubkey: &str,
mut peers: Vec<Peer>,
) -> Result<Vec<Peer>, error::Mesh> {
let this_peer = peers
.iter()
.position(|peer| peer.public_key == interface_pubkey)
.ok_or(WgMeshError::PeerNotPartOfMesh(interface_pubkey.to_string()))
.unwrap();
.ok_or(error::Mesh::PeerNotPartOfMesh(interface_pubkey.to_string()))?;

let this_peer = peers.swap_remove(this_peer);

Expand All @@ -18,5 +20,5 @@ pub fn filter_peers(interface_pubkey: &str, mut peers: Vec<Peer>) -> Vec<Peer> {
.filter(|peer| !this_peer.has_public_ipv4_address || !peer.has_public_ipv4_address)
.collect();

peers
Ok(peers)
}
15 changes: 9 additions & 6 deletions src/mesh/wgmesh.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::error;
use crate::mesh::filter_peers::filter_peers;
use crate::traits::{NodeRepository, RoutingService, Wireguard};

Expand Down Expand Up @@ -31,23 +32,25 @@ where
}
}

pub async fn execute(self, mesh_record: &str) {
let interface_pubkey = self.wireguard.get_interface_pubkey().unwrap();
pub async fn execute(self, mesh_record: &str) -> Result<(), error::WgMesh> {
let interface_pubkey = self.wireguard.get_interface_pubkey()?;

// first, get a list of all Peers belonging to the mesh
let peers = self.node_repository.fetch_all_peers(mesh_record).await;
let peers = self.node_repository.fetch_all_peers(mesh_record).await?;

// get the list of peers we actually want to peer with
let relevant_peers = filter_peers(&interface_pubkey, peers);
let relevant_peers = filter_peers(&interface_pubkey, peers)?;

// and add them to wireguard
// TODO: get list of current peers and remove, update, or add peers
self.wireguard.replace_peers(relevant_peers.as_slice());
self.wireguard.replace_peers(relevant_peers.as_slice())?;

// …as well as adding direct routes for the "allowed ips"
// TODO: get list of current routes and remove, update, or add routes
self.routing_service
.add_routes(relevant_peers.as_slice())
.await;
.await?;

Ok(())
}
}
27 changes: 16 additions & 11 deletions src/routing/routing_service.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::error;
use async_trait::async_trait;
use error::Routing::NoSuchInterface;
use futures::TryStreamExt;
use ipnet::Ipv4Net;
use netlink_packet_core::ErrorMessage;
use nix::errno::Errno;
use rtnetlink::Error::NetlinkError;
use rtnetlink::RouteHandle;

use crate::error::{WgMesh as WgMeshError, WgMesh};
use crate::model::Peer;
use crate::traits::RoutingService;

Expand All @@ -30,25 +31,28 @@ impl RoutingServiceImpl {

#[async_trait(?Send)]
impl RoutingService for RoutingServiceImpl {
async fn add_routes(&self, peers: &[Peer]) {
async fn add_routes(&self, peers: &[Peer]) -> Result<(), error::Routing> {
let mut link_handle = self.handle.link();

let interface = link_handle
.get()
.match_name(self.interface_name.clone())
.execute()
.try_next()
.await
.unwrap()
.ok_or("no such interface")
.unwrap();
.await?;

let interface =
interface.ok_or_else(|| NoSuchInterface(self.interface_name.to_string()))?;

let route_add_requests: Vec<_> = peers
.iter()
.flat_map(|peer| {
peer.allowed_ips
.iter()
.map(|i| i.parse::<Ipv4Net>().unwrap())
.map(|i| {
i.parse::<Ipv4Net>()
.expect("got invalid IPv4 address for Peer's allowed_ips")
})
.map(|ip| {
self.route_handle
.add()
Expand All @@ -63,16 +67,17 @@ impl RoutingService for RoutingServiceImpl {
for r in route_add_requests {
r.execute()
.await
.or_else(|e| -> Result<(), WgMeshError> {
.or_else(|e| -> Result<(), error::Routing> {
match e {
// TODO: this is not very elegant, so better check in the first place if something has to be added or not
NetlinkError(ErrorMessage {
code: Some(code), ..
}) if i32::from(-code) == Errno::EEXIST as i32 => Ok(()),
err => Err(WgMesh::NetlinkError(err)),
err => Err(error::Routing::NetlinkError(err)),
}
})
.unwrap();
})?;
}

Ok(())
}
}
Loading

0 comments on commit bb0aab6

Please sign in to comment.