Skip to content

Commit

Permalink
Use tokio_util::sync::CancellationToken to quit service
Browse files Browse the repository at this point in the history
  • Loading branch information
sword-jin committed Jun 17, 2024
1 parent be14d12 commit 0c3acd6
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 62 deletions.
21 changes: 8 additions & 13 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};

#[cfg(feature = "noise")]
Expand All @@ -33,7 +34,7 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE
// The entrypoint of running a client
pub async fn run_client(
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = config.client.ok_or_else(|| {
Expand All @@ -45,13 +46,13 @@ pub async fn run_client(
match config.transport.transport_type {
TransportType::Tcp => {
let mut client = Client::<TcpTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
TransportType::Tls => {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
{
let mut client = Client::<TlsTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
crate::helper::feature_neither_compile("native-tls", "rustls")
Expand All @@ -60,7 +61,7 @@ pub async fn run_client(
#[cfg(feature = "noise")]
{
let mut client = Client::<NoiseTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
#[cfg(not(feature = "noise"))]
crate::helper::feature_not_compile("noise")
Expand All @@ -69,7 +70,7 @@ pub async fn run_client(
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
{
let mut client = Client::<WebsocketTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
Expand Down Expand Up @@ -102,7 +103,7 @@ impl<T: 'static + Transport> Client<T> {
// The entrypoint of Client
async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
for (name, config) in &self.config.services {
Expand All @@ -119,13 +120,7 @@ impl<T: 'static + Transport> Client<T> {
// Wait for the shutdown signal
loop {
tokio::select! {
val = shutdown_rx.recv() => {
match val {
Ok(_) => {}
Err(err) => {
error!("Unable to listen for shutdown signal: {}", err);
}
}
_ = cancel.cancelled() => {
break;
},
e = update_rx.recv() => {
Expand Down
13 changes: 7 additions & 6 deletions src/config_watcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use crate::{
Config,
};
use anyhow::{Context, Result};
use tokio_util::sync::CancellationToken;
use std::{
collections::HashMap,
env,
path::{Path, PathBuf},
};
use tokio::sync::{broadcast, mpsc};
use tokio::sync::mpsc;
use tracing::{error, info, instrument};

#[cfg(feature = "notify")]
Expand Down Expand Up @@ -98,7 +99,7 @@ pub struct ConfigWatcherHandle {
}

impl ConfigWatcherHandle {
pub async fn new(path: &Path, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
pub async fn new(path: &Path, cancel: CancellationToken) -> Result<Self> {
let (event_tx, event_rx) = mpsc::unbounded_channel();
let origin_cfg = Config::from_file(path).await?;

Expand All @@ -109,7 +110,7 @@ impl ConfigWatcherHandle {

tokio::spawn(config_watcher(
path.to_owned(),
shutdown_rx,
cancel,
event_tx,
origin_cfg,
));
Expand All @@ -132,10 +133,10 @@ async fn config_watcher(
}

#[cfg(feature = "notify")]
#[instrument(skip(shutdown_rx, event_tx, old))]
#[instrument(skip(cancel, event_tx, old))]
async fn config_watcher(
path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
event_tx: mpsc::UnboundedSender<ConfigChange>,
mut old: Config,
) -> Result<()> {
Expand Down Expand Up @@ -190,7 +191,7 @@ async fn config_watcher(
None => break
}
},
_ = shutdown_rx.recv() => break
_ = cancel.cancelled() => break
}
}

Expand Down
20 changes: 10 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub use constants::UDP_BUFFER_SIZE;

use anyhow::Result;
use tokio::sync::{broadcast, mpsc};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};

#[cfg(feature = "client")]
Expand Down Expand Up @@ -59,7 +60,7 @@ fn genkey(curve: Option<KeypairType>) -> Result<()> {
crate::helper::feature_not_compile("nosie")
}

pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
pub async fn run(args: Cli, cancel: CancellationToken) -> Result<()> {
if args.genkey.is_some() {
return genkey(args.genkey.unwrap());
}
Expand All @@ -69,10 +70,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()

// Spawn a config watcher. The watcher will send a initial signal to start the instance with a config
let config_path = args.config_path.as_ref().unwrap();
let mut cfg_watcher = ConfigWatcherHandle::new(config_path, shutdown_rx).await?;
let mut cfg_watcher = ConfigWatcherHandle::new(config_path, cancel).await?;

// shutdown_tx owns the instance
let (shutdown_tx, _) = broadcast::channel(1);
let local_cancel_tx = CancellationToken::new();

// (The join handle of the last instance, The service update channel sender)
let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ConfigChange>)> = None;
Expand All @@ -82,7 +82,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
ConfigChange::General(config) => {
if let Some((i, _)) = last_instance {
info!("General configuration change detected. Restarting...");
shutdown_tx.send(true)?;
local_cancel_tx.cancel();
i.await??;
}

Expand All @@ -94,7 +94,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
tokio::spawn(run_instance(
*config,
args.clone(),
shutdown_tx.subscribe(),
local_cancel_tx.clone(),
service_update_rx,
)),
service_update_tx,
Expand All @@ -109,15 +109,15 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
}
}

let _ = shutdown_tx.send(true);
local_cancel_tx.cancel();

Ok(())
}

async fn run_instance(
config: Config,
args: Cli,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
service_update: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
match determine_run_mode(&config, &args) {
Expand All @@ -126,13 +126,13 @@ async fn run_instance(
#[cfg(not(feature = "client"))]
crate::helper::feature_not_compile("client");
#[cfg(feature = "client")]
run_client(config, shutdown_rx, service_update).await
run_client(config, cancel, service_update).await
}
RunMode::Server => {
#[cfg(not(feature = "server"))]
crate::helper::feature_not_compile("server");
#[cfg(feature = "server")]
run_server(config, shutdown_rx, service_update).await
run_server(config, cancel, service_update).await
}
}
}
Expand Down
14 changes: 6 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
use anyhow::Result;
use clap::Parser;
use rathole::{run, Cli};
use tokio::{signal, sync::broadcast};
use tokio::signal;
use tracing_subscriber::EnvFilter;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() -> Result<()> {
let args = Cli::parse();

let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
let cancel_tx = CancellationToken::new();
let cancel_rx = cancel_tx.clone();
tokio::spawn(async move {
if let Err(e) = signal::ctrl_c().await {
// Something really weird happened. So just panic
panic!("Failed to listen for the ctrl-c signal: {:?}", e);
}

if let Err(e) = shutdown_tx.send(true) {
// shutdown signal must be catched and handle properly
// `rx` must not be dropped
panic!("Failed to send shutdown signal: {:?}", e);
}
cancel_tx.cancel(); // synchronously
});

#[cfg(feature = "console")]
Expand All @@ -41,5 +39,5 @@ async fn main() -> Result<()> {
.init();
}

run(args, shutdown_rx).await
run(args, cancel_rx).await
}
15 changes: 8 additions & 7 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;

use rand::RngCore;
use tokio_util::sync::CancellationToken;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -41,7 +42,7 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake
// The entrypoint of running a server
pub async fn run_server(
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = match config.server {
Expand All @@ -54,13 +55,13 @@ pub async fn run_server(
match config.transport.transport_type {
TransportType::Tcp => {
let mut server = Server::<TcpTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
TransportType::Tls => {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
{
let mut server = Server::<TlsTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
crate::helper::feature_neither_compile("native-tls", "rustls")
Expand All @@ -69,7 +70,7 @@ pub async fn run_server(
#[cfg(feature = "noise")]
{
let mut server = Server::<NoiseTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
#[cfg(not(feature = "noise"))]
crate::helper::feature_not_compile("noise")
Expand All @@ -78,7 +79,7 @@ pub async fn run_server(
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
{
let mut server = Server::<WebsocketTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
Expand Down Expand Up @@ -134,7 +135,7 @@ impl<T: 'static + Transport> Server<T> {
// The entry point of Server
pub async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
// Listen at `server.bind_addr`
Expand Down Expand Up @@ -205,7 +206,7 @@ impl<T: 'static + Transport> Server<T> {
}
},
// Wait for the shutdown signal
_ = shutdown_rx.recv() => {
_ = cancel.cancelled() => {
info!("Shuting down gracefully...");
break;
},
Expand Down
10 changes: 5 additions & 5 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,36 @@ use anyhow::Result;
use tokio::{
io::{self, AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, ToSocketAddrs},
sync::broadcast,
};
use tokio_util::sync::CancellationToken;

pub const PING: &str = "ping";
pub const PONG: &str = "pong";

pub async fn run_rathole_server(
config_path: &str,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
) -> Result<()> {
let cli = rathole::Cli {
config_path: Some(PathBuf::from(config_path)),
server: true,
client: false,
..Default::default()
};
rathole::run(cli, shutdown_rx).await
rathole::run(cli, cancel).await
}

pub async fn run_rathole_client(
config_path: &str,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
) -> Result<()> {
let cli = rathole::Cli {
config_path: Some(PathBuf::from(config_path)),
server: false,
client: true,
..Default::default()
};
rathole::run(cli, shutdown_rx).await
rathole::run(cli, cancel).await
}

pub mod tcp {
Expand Down
Loading

0 comments on commit 0c3acd6

Please sign in to comment.