From 077a5df170ff87c22e2402b93433bfd3148967b6 Mon Sep 17 00:00:00 2001 From: ntqbit Date: Wed, 8 Jan 2025 10:40:16 +0000 Subject: [PATCH] Add WireGuard handshake key logging --- Cargo.lock | 2 +- Cargo.toml | 2 +- mitmproxy-rs/src/server/wireguard.rs | 21 ++++++++++++++++++++- src/packet_sources/wireguard.rs | 24 +++++++++++++++++++++++- 4 files changed, 45 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 164100a9..a50fbfe6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -446,7 +446,7 @@ dependencies = [ [[package]] name = "boringtun" version = "0.6.0" -source = "git+https://github.com/cloudflare/boringtun?rev=e3252d9c4f4c8fc628995330f45369effd4660a1#e3252d9c4f4c8fc628995330f45369effd4660a1" +source = "git+https://github.com/ntqbit/boringtun?branch=feature/key_logger#0437b44b3d34bf74cd364f51f62c8ac3f4b6962d" dependencies = [ "aead", "base64 0.13.1", diff --git a/Cargo.toml b/Cargo.toml index 82ea5f7b..1f8d8592 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ socket2 = "0.5.8" [patch.crates-io] # tokio = { path = "../tokio/tokio" } -boringtun = { git = 'https://github.com/cloudflare/boringtun', rev = 'e3252d9c4f4c8fc628995330f45369effd4660a1' } +boringtun = { git = "https://github.com/ntqbit/boringtun", branch = "feature/key_logger"} [target.'cfg(windows)'.dependencies.windows] version = "0.58.0" diff --git a/mitmproxy-rs/src/server/wireguard.rs b/mitmproxy-rs/src/server/wireguard.rs index 1d5ea8b5..76f8fb6f 100644 --- a/mitmproxy-rs/src/server/wireguard.rs +++ b/mitmproxy-rs/src/server/wireguard.rs @@ -6,7 +6,7 @@ use mitmproxy::packet_sources::wireguard::WireGuardConf; use pyo3::prelude::*; -use boringtun::x25519::PublicKey; +use boringtun::{noise::keys_logger::KeyLogger, x25519::PublicKey}; use crate::server::base::Server; @@ -60,7 +60,9 @@ impl WireGuardServer { /// - `peer_public_keys`: List of public X25519 keys for WireGuard peers as base64-encoded strings. /// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. /// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. +/// - `key_logger`: An optional function that will be called for each key when handshake is completed. #[pyfunction] +#[pyo3(signature = (host, port, private_key, peer_public_keys, handle_tcp_stream, handle_udp_stream, key_logger=None))] pub fn start_wireguard_server( py: Python<'_>, host: String, @@ -69,20 +71,37 @@ pub fn start_wireguard_server( peer_public_keys: Vec, handle_tcp_stream: PyObject, handle_udp_stream: PyObject, + key_logger: Option, ) -> PyResult> { let private_key = string_to_key(private_key)?; let peer_public_keys = peer_public_keys .into_iter() .map(string_to_key) .collect::>>()?; + + let key_logger = key_logger + .map(|key_logger| Box::new(PythonKeyLogger(key_logger)) as Box); + let conf = WireGuardConf { host, port, private_key, peer_public_keys, + key_logger, }; pyo3_async_runtimes::tokio::future_into_py(py, async move { let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; Ok(WireGuardServer { server, local_addr }) }) } + +struct PythonKeyLogger(PyObject); + +impl KeyLogger for PythonKeyLogger { + fn log_key(&self, name: &str, keymaterial: &str) { + Python::with_gil(|py| { + // The error is intentionally ignored. + let _ = self.0.call1(py, (name, keymaterial)); + }); + } +} diff --git a/src/packet_sources/wireguard.rs b/src/packet_sources/wireguard.rs index 19fb5ef4..d1af5dc9 100755 --- a/src/packet_sources/wireguard.rs +++ b/src/packet_sources/wireguard.rs @@ -8,6 +8,8 @@ use crate::messages::{ use crate::network::{add_network_layer, MAX_PACKET_SIZE}; use crate::packet_sources::{PacketSourceConf, PacketSourceTask}; use anyhow::{anyhow, Context, Result}; +use boringtun::noise::handshake::HandshakeKeysListener; +use boringtun::noise::keys_logger::{KeyLogger, KeysLogger}; use boringtun::noise::{ errors::WireGuardError, handshake::parse_handshake_anon, Packet, Tunn, TunnResult, }; @@ -40,6 +42,15 @@ pub struct WireGuardConf { pub port: u16, pub private_key: StaticSecret, pub peer_public_keys: Vec, + pub key_logger: Option>, +} + +struct KeyLoggerWrapper(Box); + +impl KeyLogger for KeyLoggerWrapper { + fn log_key(&self, name: &str, keymaterial: &str) { + self.0.log_key(name, keymaterial); + } } impl PacketSourceConf for WireGuardConf { @@ -59,13 +70,17 @@ impl PacketSourceConf for WireGuardConf { let (network_task_handle, net_tx, net_rx) = add_network_layer(transport_events_tx, transport_commands_rx, shutdown); + let handshake_keys_listener = self + .key_logger + .map(|key_logger| Arc::new(KeysLogger::new(KeyLoggerWrapper(key_logger)))); + // initialize WireGuard server let mut peers_by_idx = HashMap::new(); let mut peers_by_key = HashMap::new(); for public_key in self.peer_public_keys { let index = peers_by_idx.len() as u32; - let tunnel = Tunn::new( + let mut tunnel = Tunn::new( self.private_key.clone(), public_key, None, @@ -75,6 +90,13 @@ impl PacketSourceConf for WireGuardConf { ) .map_err(|error| anyhow!(error))?; + // Set key logger, if any. + if let Some(keys_listener) = &handshake_keys_listener { + tunnel.set_handshake_keys_listener( + Arc::clone(keys_listener) as Arc + ); + } + let peer = Arc::new(Mutex::new(WireGuardPeer { tunnel, endpoint: None,