Skip to content

Update example tls-graceful-shutdown to axum 0.7 #2384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions examples/tls-graceful-shutdown/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@ publish = false

[dependencies]
axum = { path = "../../axum" }
axum-server = { version = "0.3", features = ["tls-rustls"] }
hyper = { version = "0.14", features = ["full"] }
futures-util = { version = "0.3", default-features = false }
hyper = { version = "1.0.0", features = ["full"] }
hyper-util = { version = "0.1" }
rustls-pemfile = "1.0.4"
scopeguard = "1.2.0"
tokio = { version = "1", features = ["full"] }
tokio-rustls = "0.24.1"
tower = { version = "0.4", features = [] }
tower-service = "0.3.2"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
376 changes: 241 additions & 135 deletions examples/tls-graceful-shutdown/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,140 +4,246 @@
//! cargo run -p example-tls-graceful-shutdown
//! ```

fn main() {
// This example has not yet been updated to Hyper 1.0
use axum::extract::Host;
use axum::handler::HandlerWithoutStateExt;
use axum::http::{StatusCode, Uri};
use axum::response::Redirect;
use axum::{extract::Request, routing::get, BoxError, Router};
use futures_util::{pin_mut, FutureExt};
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::future::{Future, IntoFuture};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use std::{
fs::File,
io::BufReader,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::net::TcpListener;
use tokio::{select, signal};
use tokio_rustls::{
rustls::{Certificate, PrivateKey, ServerConfig},
TlsAcceptor,
};
use tower_service::Service;
use tracing::{error, info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// !!!!!!!!! WARNING !!!!!!!!!!
//
// The code only gracefully shutdowns connections/tasks that are managed by axum.
// If inside your handler you spawn a task (with tokio::spawn) and return a response. This task will not be awaited before shutdown.
// So it is up to you to track those tasks and await for them correctly before terminating
//
// !!!!!!!!! WARNING !!!!!!!!!!
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();

let rustls_config = rustls_server_config(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("key.pem"),
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("cert.pem"),
);

let app = Router::new().route("/", get(handler));

let nb_inflight_requests = Arc::new(AtomicU64::new(0));
let shutdown_signal = mk_shutdown_signal().fuse();
let tls_acceptor = TlsAcceptor::from(rustls_config);

let ports = Ports {
http: 3080,
https: 3443,
};
let bind = format!("[::1]:{}", ports.https);
let tcp_listener = TcpListener::bind(&bind).await.unwrap();
info!(
"HTTPS server listening on {bind}. To contact curl -k https://localhost:{}",
ports.https
);
tokio::spawn(redirect_http_to_https(ports, mk_shutdown_signal()));

pin_mut!(shutdown_signal);
loop {
let tower_service = app.clone();
let tls_acceptor = tls_acceptor.clone();

// Wait for new tcp connection or shutdown signal
let (cnx, addr) = select! {
biased;

_ = &mut shutdown_signal => {
break;
}

cnx = tcp_listener.accept() => {
let Ok(cnx) = cnx else {
error!("error accepting connection");
break;
};
nb_inflight_requests.fetch_add(1, Ordering::Relaxed);
cnx
}
};

let nb_inflight_requests = nb_inflight_requests.clone();
tokio::spawn(async move {
let _guard = scopeguard::guard((), |_| {
nb_inflight_requests.fetch_sub(1, Ordering::Relaxed);
});

// Wait for tls handshake to happen
let Ok(stream) = tls_acceptor.accept(cnx).await else {
error!("error during tls handshake connection from {}", addr);
return;
};

// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
// `TokioIo` converts between them.
let stream = TokioIo::new(stream);

// Hyper has also its own `Service` trait and doesn't use tower. We can use
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
// `tower::Service::call`.
let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
// We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
// tower's `Service` requires `&mut self`.
//
// We don't need to call `poll_ready` since `Router` is always ready.
tower_service.clone().call(request)
});

let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(stream, hyper_service)
.await;

if let Err(err) = ret {
warn!("error serving connection from {}: {}", addr, err);
}
});
}

// drop tcp_listener to stop accepting new connections
drop(tls_acceptor);
drop(tcp_listener);
info!("Server is shutting down. Waiting for inflight requests to complete before terminating");
loop {
let nb_inflights = nb_inflight_requests.load(Ordering::Relaxed);
if nb_inflights == 0 {
break;
}
info!("Server is shutting down. Waiting for {} inflight requests to complete before terminating", nb_inflights);
tokio::time::sleep(Duration::from_secs(1)).await;
}
}

//use axum::{
// extract::Host,
// handler::HandlerWithoutStateExt,
// http::{StatusCode, Uri},
// response::Redirect,
// routing::get,
// BoxError, Router,
//};
//use axum_server::tls_rustls::RustlsConfig;
//use std::{future::Future, net::SocketAddr, path::PathBuf, time::Duration};
//use tokio::signal;
//use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

//#[derive(Clone, Copy)]
//struct Ports {
// http: u16,
// https: u16,
//}

//#[tokio::main]
//async fn main() {
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();

// let ports = Ports {
// http: 7878,
// https: 3000,
// };

// //Create a handle for our TLS server so the shutdown signal can all shutdown
// let handle = axum_server::Handle::new();
// //save the future for easy shutting down of redirect server
// let shutdown_future = shutdown_signal(handle.clone());

// // optional: spawn a second server to redirect http requests to this server
// tokio::spawn(redirect_http_to_https(ports, shutdown_future));

// // configure certificate and private key used by https
// let config = RustlsConfig::from_pem_file(
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("cert.pem"),
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("key.pem"),
// )
// .await
// .unwrap();

// let app = Router::new().route("/", get(handler));

// // run https server
// let addr = SocketAddr::from(([127, 0, 0, 1], ports.https));
// tracing::debug!("listening on {addr}");
// axum_server::bind_rustls(addr, config)
// .handle(handle)
// .serve(app.into_make_service())
// .await
// .unwrap();
//}

//async fn shutdown_signal(handle: axum_server::Handle) {
// let ctrl_c = async {
// signal::ctrl_c()
// .await
// .expect("failed to install Ctrl+C handler");
// };

// #[cfg(unix)]
// let terminate = async {
// signal::unix::signal(signal::unix::SignalKind::terminate())
// .expect("failed to install signal handler")
// .recv()
// .await;
// };

// #[cfg(not(unix))]
// let terminate = std::future::pending::<()>();

// tokio::select! {
// _ = ctrl_c => {},
// _ = terminate => {},
// }

// tracing::info!("Received termination signal shutting down");
// handle.graceful_shutdown(Some(Duration::from_secs(10))); // 10 secs is how long docker will wait
// // to force shutdown
//}

//async fn handler() -> &'static str {
// "Hello, World!"
//}

//async fn redirect_http_to_https(ports: Ports, signal: impl Future<Output = ()>) {
// fn make_https(host: String, uri: Uri, ports: Ports) -> Result<Uri, BoxError> {
// let mut parts = uri.into_parts();

// parts.scheme = Some(axum::http::uri::Scheme::HTTPS);

// if parts.path_and_query.is_none() {
// parts.path_and_query = Some("/".parse().unwrap());
// }

// let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string());
// parts.authority = Some(https_host.parse()?);

// Ok(Uri::from_parts(parts)?)
// }

// let redirect = move |Host(host): Host, uri: Uri| async move {
// match make_https(host, uri, ports) {
// Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
// Err(error) => {
// tracing::warn!(%error, "failed to convert URI to HTTPS");
// Err(StatusCode::BAD_REQUEST)
// }
// }
// };

// let addr = SocketAddr::from(([127, 0, 0, 1], ports.http));
// //let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
// tracing::debug!("listening on {addr}");
// hyper::Server::bind(&addr)
// .serve(redirect.into_make_service())
// .with_graceful_shutdown(signal)
// .await
// .unwrap();
//}
async fn handler() -> &'static str {
tokio::time::sleep(Duration::from_secs(5)).await;
"Hello, World!"
}

fn rustls_server_config(key: impl AsRef<Path>, cert: impl AsRef<Path>) -> Arc<ServerConfig> {
let mut key_reader = BufReader::new(File::open(key).unwrap());
let mut cert_reader = BufReader::new(File::open(cert).unwrap());

let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0));
let certs = certs(&mut cert_reader)
.unwrap()
.into_iter()
.map(Certificate)
.collect();

let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.expect("bad certificate/key");

config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

Arc::new(config)
}

async fn mk_shutdown_signal() {
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};

#[cfg(not(unix))]
let terminate = std::future::pending::<()>();

select! {
_ = signal::ctrl_c() => {},
_ = terminate => {},
}

info!("Received termination signal shutting down");
}

// Redirect HTTP to HTTPS
#[derive(Clone, Copy)]
struct Ports {
http: u16,
https: u16,
}

async fn redirect_http_to_https(ports: Ports, signal: impl Future<Output = ()>) {
fn make_https(host: String, uri: Uri, ports: Ports) -> Result<Uri, BoxError> {
let mut parts = uri.into_parts();

parts.scheme = Some(axum::http::uri::Scheme::HTTPS);

if parts.path_and_query.is_none() {
parts.path_and_query = Some("/".parse().unwrap());
}

let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string());
parts.authority = Some(https_host.parse()?);

Ok(Uri::from_parts(parts)?)
}

let redirect = move |Host(host): Host, uri: Uri| async move {
match make_https(host, uri, ports) {
Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
Err(error) => {
warn!(%error, "failed to convert URI to HTTPS");
Err(StatusCode::BAD_REQUEST)
}
}
};

let bind = format!("[::1]:{}", ports.http);
let listener = TcpListener::bind(&bind).await.unwrap();
info!(
"HTTP server listening on {bind}. To contact curl http://localhost:{}",
ports.http
);
let server = axum::serve(listener, redirect.into_make_service()).into_future();

select! {
biased;

_ = signal => {},
_ = server => {},
}

info!("HTTP server shutdown");
}