diff --git a/src/client.rs b/src/client.rs index 2564869c..55616fc1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,6 +19,7 @@ use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time::{self, Duration, Instant}; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span}; #[cfg(feature = "noise")] @@ -33,7 +34,7 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE // The entrypoint of running a client pub async fn run_client( config: Config, - shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, update_rx: mpsc::Receiver, ) -> Result<()> { let config = config.client.ok_or_else(|| { @@ -45,13 +46,13 @@ pub async fn run_client( match config.transport.transport_type { TransportType::Tcp => { let mut client = Client::::from(config).await?; - client.run(shutdown_rx, update_rx).await + client.run(cancel, update_rx).await } TransportType::Tls => { #[cfg(any(feature = "native-tls", feature = "rustls"))] { let mut client = Client::::from(config).await?; - client.run(shutdown_rx, update_rx).await + client.run(cancel, update_rx).await } #[cfg(not(any(feature = "native-tls", feature = "rustls")))] crate::helper::feature_neither_compile("native-tls", "rustls") @@ -60,7 +61,7 @@ pub async fn run_client( #[cfg(feature = "noise")] { let mut client = Client::::from(config).await?; - client.run(shutdown_rx, update_rx).await + client.run(cancel, update_rx).await } #[cfg(not(feature = "noise"))] crate::helper::feature_not_compile("noise") @@ -69,7 +70,7 @@ pub async fn run_client( #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] { let mut client = Client::::from(config).await?; - client.run(shutdown_rx, update_rx).await + client.run(cancel, update_rx).await } #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))] crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls") @@ -102,7 +103,7 @@ impl Client { // The entrypoint of Client async fn run( &mut self, - mut shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, mut update_rx: mpsc::Receiver, ) -> Result<()> { for (name, config) in &self.config.services { @@ -119,13 +120,7 @@ impl Client { // Wait for the shutdown signal loop { tokio::select! { - val = shutdown_rx.recv() => { - match val { - Ok(_) => {} - Err(err) => { - error!("Unable to listen for shutdown signal: {}", err); - } - } + _ = cancel.cancelled() => { break; }, e = update_rx.recv() => { diff --git a/src/config_watcher.rs b/src/config_watcher.rs index 993fdcce..bf58cb62 100644 --- a/src/config_watcher.rs +++ b/src/config_watcher.rs @@ -3,12 +3,13 @@ use crate::{ Config, }; use anyhow::{Context, Result}; +use tokio_util::sync::CancellationToken; use std::{ collections::HashMap, env, path::{Path, PathBuf}, }; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::mpsc; use tracing::{error, info, instrument}; #[cfg(feature = "notify")] @@ -98,7 +99,7 @@ pub struct ConfigWatcherHandle { } impl ConfigWatcherHandle { - pub async fn new(path: &Path, shutdown_rx: broadcast::Receiver) -> Result { + pub async fn new(path: &Path, cancel: CancellationToken) -> Result { let (event_tx, event_rx) = mpsc::unbounded_channel(); let origin_cfg = Config::from_file(path).await?; @@ -109,7 +110,7 @@ impl ConfigWatcherHandle { tokio::spawn(config_watcher( path.to_owned(), - shutdown_rx, + cancel, event_tx, origin_cfg, )); @@ -132,10 +133,10 @@ async fn config_watcher( } #[cfg(feature = "notify")] -#[instrument(skip(shutdown_rx, event_tx, old))] +#[instrument(skip(cancel, event_tx, old))] async fn config_watcher( path: PathBuf, - mut shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, event_tx: mpsc::UnboundedSender, mut old: Config, ) -> Result<()> { @@ -190,7 +191,7 @@ async fn config_watcher( None => break } }, - _ = shutdown_rx.recv() => break + _ = cancel.cancelled() => break } } diff --git a/src/lib.rs b/src/lib.rs index 65beb7f7..3e371b55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub use constants::UDP_BUFFER_SIZE; use anyhow::Result; use tokio::sync::{broadcast, mpsc}; +use tokio_util::sync::CancellationToken; use tracing::{debug, info}; #[cfg(feature = "client")] @@ -59,7 +60,7 @@ fn genkey(curve: Option) -> Result<()> { crate::helper::feature_not_compile("nosie") } -pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<()> { +pub async fn run(args: Cli, cancel: CancellationToken) -> Result<()> { if args.genkey.is_some() { return genkey(args.genkey.unwrap()); } @@ -69,10 +70,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() // Spawn a config watcher. The watcher will send a initial signal to start the instance with a config let config_path = args.config_path.as_ref().unwrap(); - let mut cfg_watcher = ConfigWatcherHandle::new(config_path, shutdown_rx).await?; + let mut cfg_watcher = ConfigWatcherHandle::new(config_path, cancel).await?; - // shutdown_tx owns the instance - let (shutdown_tx, _) = broadcast::channel(1); + let local_cancel_tx = CancellationToken::new(); // (The join handle of the last instance, The service update channel sender) let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender)> = None; @@ -82,7 +82,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() ConfigChange::General(config) => { if let Some((i, _)) = last_instance { info!("General configuration change detected. Restarting..."); - shutdown_tx.send(true)?; + local_cancel_tx.cancel(); i.await??; } @@ -94,7 +94,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() tokio::spawn(run_instance( *config, args.clone(), - shutdown_tx.subscribe(), + local_cancel_tx.clone(), service_update_rx, )), service_update_tx, @@ -109,7 +109,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() } } - let _ = shutdown_tx.send(true); + local_cancel_tx.cancel(); Ok(()) } @@ -117,7 +117,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() async fn run_instance( config: Config, args: Cli, - shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, service_update: mpsc::Receiver, ) -> Result<()> { match determine_run_mode(&config, &args) { @@ -126,13 +126,13 @@ async fn run_instance( #[cfg(not(feature = "client"))] crate::helper::feature_not_compile("client"); #[cfg(feature = "client")] - run_client(config, shutdown_rx, service_update).await + run_client(config, cancel, service_update).await } RunMode::Server => { #[cfg(not(feature = "server"))] crate::helper::feature_not_compile("server"); #[cfg(feature = "server")] - run_server(config, shutdown_rx, service_update).await + run_server(config, cancel, service_update).await } } } diff --git a/src/main.rs b/src/main.rs index 92ab75c1..19cf5903 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,23 @@ use anyhow::Result; use clap::Parser; use rathole::{run, Cli}; -use tokio::{signal, sync::broadcast}; +use tokio::signal; use tracing_subscriber::EnvFilter; +use tokio_util::sync::CancellationToken; #[tokio::main] async fn main() -> Result<()> { let args = Cli::parse(); - let (shutdown_tx, shutdown_rx) = broadcast::channel::(1); + let cancel_tx = CancellationToken::new(); + let cancel_rx = cancel_tx.clone(); tokio::spawn(async move { if let Err(e) = signal::ctrl_c().await { // Something really weird happened. So just panic panic!("Failed to listen for the ctrl-c signal: {:?}", e); } - if let Err(e) = shutdown_tx.send(true) { - // shutdown signal must be catched and handle properly - // `rx` must not be dropped - panic!("Failed to send shutdown signal: {:?}", e); - } + cancel_tx.cancel(); // synchronously }); #[cfg(feature = "console")] @@ -41,5 +39,5 @@ async fn main() -> Result<()> { .init(); } - run(args, shutdown_rx).await + run(args, cancel_rx).await } diff --git a/src/server.rs b/src/server.rs index a4c49482..1b1c59fe 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,6 +14,7 @@ use backoff::backoff::Backoff; use backoff::ExponentialBackoff; use rand::RngCore; +use tokio_util::sync::CancellationToken; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -41,7 +42,7 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake // The entrypoint of running a server pub async fn run_server( config: Config, - shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, update_rx: mpsc::Receiver, ) -> Result<()> { let config = match config.server { @@ -54,13 +55,13 @@ pub async fn run_server( match config.transport.transport_type { TransportType::Tcp => { let mut server = Server::::from(config).await?; - server.run(shutdown_rx, update_rx).await?; + server.run(cancel, update_rx).await?; } TransportType::Tls => { #[cfg(any(feature = "native-tls", feature = "rustls"))] { let mut server = Server::::from(config).await?; - server.run(shutdown_rx, update_rx).await?; + server.run(cancel, update_rx).await?; } #[cfg(not(any(feature = "native-tls", feature = "rustls")))] crate::helper::feature_neither_compile("native-tls", "rustls") @@ -69,7 +70,7 @@ pub async fn run_server( #[cfg(feature = "noise")] { let mut server = Server::::from(config).await?; - server.run(shutdown_rx, update_rx).await?; + server.run(cancel, update_rx).await?; } #[cfg(not(feature = "noise"))] crate::helper::feature_not_compile("noise") @@ -78,7 +79,7 @@ pub async fn run_server( #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] { let mut server = Server::::from(config).await?; - server.run(shutdown_rx, update_rx).await?; + server.run(cancel, update_rx).await?; } #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))] crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls") @@ -134,7 +135,7 @@ impl Server { // The entry point of Server pub async fn run( &mut self, - mut shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, mut update_rx: mpsc::Receiver, ) -> Result<()> { // Listen at `server.bind_addr` @@ -205,7 +206,7 @@ impl Server { } }, // Wait for the shutdown signal - _ = shutdown_rx.recv() => { + _ = cancel.cancelled() => { info!("Shuting down gracefully..."); break; }, diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 9e59f92d..191759e2 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -4,15 +4,15 @@ use anyhow::Result; use tokio::{ io::{self, AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, ToSocketAddrs}, - sync::broadcast, }; +use tokio_util::sync::CancellationToken; pub const PING: &str = "ping"; pub const PONG: &str = "pong"; pub async fn run_rathole_server( config_path: &str, - shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, ) -> Result<()> { let cli = rathole::Cli { config_path: Some(PathBuf::from(config_path)), @@ -20,12 +20,12 @@ pub async fn run_rathole_server( client: false, ..Default::default() }; - rathole::run(cli, shutdown_rx).await + rathole::run(cli, cancel).await } pub async fn run_rathole_client( config_path: &str, - shutdown_rx: broadcast::Receiver, + cancel: CancellationToken, ) -> Result<()> { let cli = rathole::Cli { config_path: Some(PathBuf::from(config_path)), @@ -33,7 +33,7 @@ pub async fn run_rathole_client( client: true, ..Default::default() }; - rathole::run(cli, shutdown_rx).await + rathole::run(cli, cancel).await } pub mod tcp { diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 7b5d408d..30981d84 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,11 +1,11 @@ use anyhow::{Ok, Result}; use common::{run_rathole_client, PING, PONG}; use rand::Rng; +use tokio_util::sync::CancellationToken; use std::time::Duration; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpStream, UdpSocket}, - sync::broadcast, time, }; use tracing::{debug, info, instrument}; @@ -117,13 +117,13 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { return Ok(()); } - let (client_shutdown_tx, client_shutdown_rx) = broadcast::channel(1); - let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel(1); + let (cancel_client_tx, cancel_server_tx) = (CancellationToken::new(), CancellationToken::new()); // Start the client info!("start the client"); + let cancel_client_rx = cancel_client_tx.clone(); let client = tokio::spawn(async move { - run_rathole_client(config_path, client_shutdown_rx) + run_rathole_client(config_path, cancel_client_rx) .await .unwrap(); }); @@ -133,8 +133,9 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { // Start the server info!("start the server"); + let cancel_server_rx = cancel_server_tx.clone(); let server = tokio::spawn(async move { - run_rathole_server(config_path, server_shutdown_rx) + run_rathole_server(config_path, cancel_server_rx) .await .unwrap(); }); @@ -149,13 +150,14 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { // Simulate the client crash and restart info!("shutdown the client"); - client_shutdown_tx.send(true)?; + cancel_client_tx.cancel(); let _ = tokio::join!(client); info!("restart the client"); - let client_shutdown_rx = client_shutdown_tx.subscribe(); + let restart_client_cancel_tx = CancellationToken::new(); + let restart_client_cancel_rx = restart_client_cancel_tx.clone(); let client = tokio::spawn(async move { - run_rathole_client(config_path, client_shutdown_rx) + run_rathole_client(config_path, restart_client_cancel_rx) .await .unwrap(); }); @@ -170,13 +172,14 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { // Simulate the server crash and restart info!("shutdown the server"); - server_shutdown_tx.send(true)?; + cancel_server_tx.cancel(); let _ = tokio::join!(server); info!("restart the server"); - let server_shutdown_rx = server_shutdown_tx.subscribe(); + let restart_server_cancel_tx = CancellationToken::new(); + let restart_server_cancel_rx = restart_server_cancel_tx.clone(); let server = tokio::spawn(async move { - run_rathole_server(config_path, server_shutdown_rx) + run_rathole_server(config_path, restart_server_cancel_rx) .await .unwrap(); }); @@ -205,8 +208,8 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { // Shutdown info!("shutdown the server and the client"); - server_shutdown_tx.send(true)?; - client_shutdown_tx.send(true)?; + restart_client_cancel_tx.cancel(); + restart_server_cancel_tx.cancel(); let _ = tokio::join!(server, client);