From 782d92f2c3561251133bed1dacfefda6c3a0e06b Mon Sep 17 00:00:00 2001 From: Dominik Stolz Date: Sat, 26 Sep 2020 11:29:41 +0200 Subject: [PATCH] Improved OnionListener --- src/lib.rs | 85 +++++++++++++++++++++++++++++++------------- src/onion/circuit.rs | 3 +- src/onion/tunnel.rs | 2 +- 3 files changed, 64 insertions(+), 26 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4eb4868..4882c31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ use crate::onion::tunnel::{self, Target, TunnelBuilder, TunnelHandler}; use anyhow::anyhow; use bytes::Bytes; use futures::stream::StreamExt; -use log::{info, trace, warn}; +use log::{debug, info, trace, warn}; use ring::rand; use std::collections::{hash_map, HashMap}; use std::net::SocketAddr; @@ -90,19 +90,24 @@ pub struct OnionTunnel { tunnel_id: TunnelId, data_tx: mpsc::UnboundedSender, data_rx: mpsc::Receiver, + counted: bool, } impl OnionTunnel { pub(crate) fn new( tunnel_id: TunnelId, + counted: bool, ) -> (Self, mpsc::Sender, mpsc::UnboundedReceiver) { - TUNNEL_COUNT.fetch_add(1, Ordering::Relaxed); + if counted { + TUNNEL_COUNT.fetch_add(1, Ordering::Relaxed); + } let (data_tx, data_rx2) = mpsc::unbounded_channel(); let (data_tx2, data_rx) = mpsc::channel(DATA_BUFFER_SIZE); let tunnel = Self { tunnel_id, data_tx, data_rx, + counted, }; (tunnel, data_tx2, data_rx2) } @@ -134,6 +139,21 @@ impl OnionTunnel { data_tx: self.data_tx.clone(), } } + + async fn forward_data( + mut self, + mut tunnel_rx: mpsc::Receiver, + mut data_tx: mpsc::Sender, + mut data_rx: mpsc::UnboundedReceiver, + ) -> Option<()> { + loop { + tokio::select! { + t = tunnel_rx.recv() => self = t?, + d = self.read() => data_tx.send(d.ok()?).await.ok()?, + d = data_rx.recv() => self.write(d?).ok()?, + } + } + } } impl fmt::Debug for OnionTunnel { @@ -146,7 +166,12 @@ impl fmt::Debug for OnionTunnel { impl Drop for OnionTunnel { fn drop(&mut self) { - TUNNEL_COUNT.fetch_sub(1, Ordering::Relaxed); + if self.counted { + let c = TUNNEL_COUNT.fetch_sub(1, Ordering::Relaxed); + debug!("Dropping tunnel with ID {}, count: {}", self.id(), c); + } else { + debug!("Dropping tunnel with ID {}", self.id()); + } } } @@ -384,34 +409,46 @@ impl OnionListener { } async fn handle_incoming(&mut self, tunnel: OnionTunnel) { - let mut tunnels = self.tunnels.lock().await; + let tunnels = self.tunnels.clone(); + let mut tunnels = tunnels.lock().await; + match tunnels.entry(tunnel.id()) { hash_map::Entry::Occupied(mut e) => { - // FIXME replace entry if send fails - e.get_mut().send(tunnel).await; + if let Err(t) = e.get_mut().send(tunnel).await { + if let Ok(tunnel_tx) = self.handle_new_tunnel(t.0).await { + e.insert(tunnel_tx); + } + } } hash_map::Entry::Vacant(e) => { - let (tunnel_tx, mut tunnel_rx) = mpsc::channel(1); - // FIXME tunnels are never removed from the map - e.insert(tunnel_tx); - - let (e_tunnel, mut e_data_tx, mut e_data_rx) = OnionTunnel::new(tunnel.id()); - self.incoming.send(e_tunnel).await; - - tokio::spawn(async move { - let mut i_tunnel = tunnel; - loop { - tokio::select! { - Some(t) = tunnel_rx.recv() => i_tunnel = t, - Ok(d) = i_tunnel.read() => e_data_tx.send(d).await.unwrap(), - Some(d) = e_data_rx.recv() => i_tunnel.write(d).unwrap(), - else => break, - } - } - }); + if let Ok(tunnel_tx) = self.handle_new_tunnel(tunnel).await { + e.insert(tunnel_tx); + } } } } + + async fn handle_new_tunnel( + &mut self, + tunnel: OnionTunnel, + ) -> Result> { + let (tunnel_tx, tunnel_rx) = mpsc::channel(1); + let (e_tunnel, e_data_tx, e_data_rx) = OnionTunnel::new(tunnel.id(), true); + self.incoming.send(e_tunnel).await?; + + tokio::spawn({ + let tunnels = self.tunnels.clone(); + async move { + let tunnel_id = tunnel.id(); + debug!("Handling incoming tunnel {}", tunnel_id); + let _ = tunnel.forward_data(tunnel_rx, e_data_tx, e_data_rx).await; + tunnels.lock().await.remove(&tunnel_id); + debug!("Finished handling incoming tunnel {}", tunnel_id); + } + }); + + Ok(tunnel_tx) + } } /// Tunnels created in one period should be torn down and rebuilt for the next period. diff --git a/src/onion/circuit.rs b/src/onion/circuit.rs index 8762c02..7e35c89 100644 --- a/src/onion/circuit.rs +++ b/src/onion/circuit.rs @@ -346,7 +346,8 @@ impl CircuitHandler { state } (TunnelRequest::Begin(tunnel_id), State::Default) => { - let (tunnel, tx, rx) = OnionTunnel::new(tunnel_id); + // counted = false because these tunnels will be mapped to counted tunnels by the OnionListener + let (tunnel, tx, rx) = OnionTunnel::new(tunnel_id, false); if self.incoming.try_send(tunnel).is_ok() { State::Endpoint { tunnel_id, diff --git a/src/onion/tunnel.rs b/src/onion/tunnel.rs index 52222b0..871a3f3 100644 --- a/src/onion/tunnel.rs +++ b/src/onion/tunnel.rs @@ -497,7 +497,7 @@ impl TunnelHandler { self.state = match (evt, state) { (Event::Switchover, State::Building { ready }) => { self.tunnel.begin(&self.builder.rng).await?; - let (tunnel, data_tx, data_rx) = OnionTunnel::new(self.tunnel.id); + let (tunnel, data_tx, data_rx) = OnionTunnel::new(self.tunnel.id, true); let _ = ready.send(Ok(tunnel)); // TODO handle closed self.spawn_next_tunnel_task(); State::Ready { data_tx, data_rx }