diff --git a/examples/tls-graceful-shutdown/Cargo.toml b/examples/tls-graceful-shutdown/Cargo.toml index 7b0169ba8f..f8e3b47ca9 100644 --- a/examples/tls-graceful-shutdown/Cargo.toml +++ b/examples/tls-graceful-shutdown/Cargo.toml @@ -6,8 +6,14 @@ publish = false [dependencies] axum = { path = "../../axum" } -axum-server = { version = "0.3", features = ["tls-rustls"] } -hyper = { version = "0.14", features = ["full"] } +futures-util = { version = "0.3", default-features = false } +hyper = { version = "1.0.0", features = ["full"] } +hyper-util = { version = "0.1" } +rustls-pemfile = "1.0.4" +scopeguard = "1.2.0" tokio = { version = "1", features = ["full"] } +tokio-rustls = "0.24.1" +tower = { version = "0.4", features = [] } +tower-service = "0.3.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tls-graceful-shutdown/src/main.rs b/examples/tls-graceful-shutdown/src/main.rs index 13251846de..b58b6c2c4d 100644 --- a/examples/tls-graceful-shutdown/src/main.rs +++ b/examples/tls-graceful-shutdown/src/main.rs @@ -4,140 +4,246 @@ //! cargo run -p example-tls-graceful-shutdown //! ``` -fn main() { - // This example has not yet been updated to Hyper 1.0 +use axum::extract::Host; +use axum::handler::HandlerWithoutStateExt; +use axum::http::{StatusCode, Uri}; +use axum::response::Redirect; +use axum::{extract::Request, routing::get, BoxError, Router}; +use futures_util::{pin_mut, FutureExt}; +use hyper::body::Incoming; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use rustls_pemfile::{certs, pkcs8_private_keys}; +use std::future::{Future, IntoFuture}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; +use std::{ + fs::File, + io::BufReader, + path::{Path, PathBuf}, + sync::Arc, +}; +use tokio::net::TcpListener; +use tokio::{select, signal}; +use tokio_rustls::{ + rustls::{Certificate, PrivateKey, ServerConfig}, + TlsAcceptor, +}; +use tower_service::Service; +use tracing::{error, info, warn}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +// !!!!!!!!! WARNING !!!!!!!!!! +// +// The code only gracefully shutdowns connections/tasks that are managed by axum. +// If inside your handler you spawn a task (with tokio::spawn) and return a response. This task will not be awaited before shutdown. +// So it is up to you to track those tasks and await for them correctly before terminating +// +// !!!!!!!!! WARNING !!!!!!!!!! +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let rustls_config = rustls_server_config( + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("key.pem"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("cert.pem"), + ); + + let app = Router::new().route("/", get(handler)); + + let nb_inflight_requests = Arc::new(AtomicU64::new(0)); + let shutdown_signal = mk_shutdown_signal().fuse(); + let tls_acceptor = TlsAcceptor::from(rustls_config); + + let ports = Ports { + http: 3080, + https: 3443, + }; + let bind = format!("[::1]:{}", ports.https); + let tcp_listener = TcpListener::bind(&bind).await.unwrap(); + info!( + "HTTPS server listening on {bind}. To contact curl -k https://localhost:{}", + ports.https + ); + tokio::spawn(redirect_http_to_https(ports, mk_shutdown_signal())); + + pin_mut!(shutdown_signal); + loop { + let tower_service = app.clone(); + let tls_acceptor = tls_acceptor.clone(); + + // Wait for new tcp connection or shutdown signal + let (cnx, addr) = select! { + biased; + + _ = &mut shutdown_signal => { + break; + } + + cnx = tcp_listener.accept() => { + let Ok(cnx) = cnx else { + error!("error accepting connection"); + break; + }; + nb_inflight_requests.fetch_add(1, Ordering::Relaxed); + cnx + } + }; + + let nb_inflight_requests = nb_inflight_requests.clone(); + tokio::spawn(async move { + let _guard = scopeguard::guard((), |_| { + nb_inflight_requests.fetch_sub(1, Ordering::Relaxed); + }); + + // Wait for tls handshake to happen + let Ok(stream) = tls_acceptor.accept(cnx).await else { + error!("error during tls handshake connection from {}", addr); + return; + }; + + // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. + // `TokioIo` converts between them. + let stream = TokioIo::new(stream); + + // Hyper has also its own `Service` trait and doesn't use tower. We can use + // `hyper::service::service_fn` to create a hyper `Service` that calls our app through + // `tower::Service::call`. + let hyper_service = hyper::service::service_fn(move |request: Request| { + // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas + // tower's `Service` requires `&mut self`. + // + // We don't need to call `poll_ready` since `Router` is always ready. + tower_service.clone().call(request) + }); + + let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(stream, hyper_service) + .await; + + if let Err(err) = ret { + warn!("error serving connection from {}: {}", addr, err); + } + }); + } + + // drop tcp_listener to stop accepting new connections + drop(tls_acceptor); + drop(tcp_listener); + info!("Server is shutting down. Waiting for inflight requests to complete before terminating"); + loop { + let nb_inflights = nb_inflight_requests.load(Ordering::Relaxed); + if nb_inflights == 0 { + break; + } + info!("Server is shutting down. Waiting for {} inflight requests to complete before terminating", nb_inflights); + tokio::time::sleep(Duration::from_secs(1)).await; + } } -//use axum::{ -// extract::Host, -// handler::HandlerWithoutStateExt, -// http::{StatusCode, Uri}, -// response::Redirect, -// routing::get, -// BoxError, Router, -//}; -//use axum_server::tls_rustls::RustlsConfig; -//use std::{future::Future, net::SocketAddr, path::PathBuf, time::Duration}; -//use tokio::signal; -//use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - -//#[derive(Clone, Copy)] -//struct Ports { -// http: u16, -// https: u16, -//} - -//#[tokio::main] -//async fn main() { -// tracing_subscriber::registry() -// .with( -// tracing_subscriber::EnvFilter::try_from_default_env() -// .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()), -// ) -// .with(tracing_subscriber::fmt::layer()) -// .init(); - -// let ports = Ports { -// http: 7878, -// https: 3000, -// }; - -// //Create a handle for our TLS server so the shutdown signal can all shutdown -// let handle = axum_server::Handle::new(); -// //save the future for easy shutting down of redirect server -// let shutdown_future = shutdown_signal(handle.clone()); - -// // optional: spawn a second server to redirect http requests to this server -// tokio::spawn(redirect_http_to_https(ports, shutdown_future)); - -// // configure certificate and private key used by https -// let config = RustlsConfig::from_pem_file( -// PathBuf::from(env!("CARGO_MANIFEST_DIR")) -// .join("self_signed_certs") -// .join("cert.pem"), -// PathBuf::from(env!("CARGO_MANIFEST_DIR")) -// .join("self_signed_certs") -// .join("key.pem"), -// ) -// .await -// .unwrap(); - -// let app = Router::new().route("/", get(handler)); - -// // run https server -// let addr = SocketAddr::from(([127, 0, 0, 1], ports.https)); -// tracing::debug!("listening on {addr}"); -// axum_server::bind_rustls(addr, config) -// .handle(handle) -// .serve(app.into_make_service()) -// .await -// .unwrap(); -//} - -//async fn shutdown_signal(handle: axum_server::Handle) { -// let ctrl_c = async { -// signal::ctrl_c() -// .await -// .expect("failed to install Ctrl+C handler"); -// }; - -// #[cfg(unix)] -// let terminate = async { -// signal::unix::signal(signal::unix::SignalKind::terminate()) -// .expect("failed to install signal handler") -// .recv() -// .await; -// }; - -// #[cfg(not(unix))] -// let terminate = std::future::pending::<()>(); - -// tokio::select! { -// _ = ctrl_c => {}, -// _ = terminate => {}, -// } - -// tracing::info!("Received termination signal shutting down"); -// handle.graceful_shutdown(Some(Duration::from_secs(10))); // 10 secs is how long docker will wait -// // to force shutdown -//} - -//async fn handler() -> &'static str { -// "Hello, World!" -//} - -//async fn redirect_http_to_https(ports: Ports, signal: impl Future) { -// fn make_https(host: String, uri: Uri, ports: Ports) -> Result { -// let mut parts = uri.into_parts(); - -// parts.scheme = Some(axum::http::uri::Scheme::HTTPS); - -// if parts.path_and_query.is_none() { -// parts.path_and_query = Some("/".parse().unwrap()); -// } - -// let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string()); -// parts.authority = Some(https_host.parse()?); - -// Ok(Uri::from_parts(parts)?) -// } - -// let redirect = move |Host(host): Host, uri: Uri| async move { -// match make_https(host, uri, ports) { -// Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), -// Err(error) => { -// tracing::warn!(%error, "failed to convert URI to HTTPS"); -// Err(StatusCode::BAD_REQUEST) -// } -// } -// }; - -// let addr = SocketAddr::from(([127, 0, 0, 1], ports.http)); -// //let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); -// tracing::debug!("listening on {addr}"); -// hyper::Server::bind(&addr) -// .serve(redirect.into_make_service()) -// .with_graceful_shutdown(signal) -// .await -// .unwrap(); -//} +async fn handler() -> &'static str { + tokio::time::sleep(Duration::from_secs(5)).await; + "Hello, World!" +} + +fn rustls_server_config(key: impl AsRef, cert: impl AsRef) -> Arc { + let mut key_reader = BufReader::new(File::open(key).unwrap()); + let mut cert_reader = BufReader::new(File::open(cert).unwrap()); + + let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0)); + let certs = certs(&mut cert_reader) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); + + let mut config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, key) + .expect("bad certificate/key"); + + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + Arc::new(config) +} + +async fn mk_shutdown_signal() { + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + select! { + _ = signal::ctrl_c() => {}, + _ = terminate => {}, + } + + info!("Received termination signal shutting down"); +} + +// Redirect HTTP to HTTPS +#[derive(Clone, Copy)] +struct Ports { + http: u16, + https: u16, +} + +async fn redirect_http_to_https(ports: Ports, signal: impl Future) { + fn make_https(host: String, uri: Uri, ports: Ports) -> Result { + let mut parts = uri.into_parts(); + + parts.scheme = Some(axum::http::uri::Scheme::HTTPS); + + if parts.path_and_query.is_none() { + parts.path_and_query = Some("/".parse().unwrap()); + } + + let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string()); + parts.authority = Some(https_host.parse()?); + + Ok(Uri::from_parts(parts)?) + } + + let redirect = move |Host(host): Host, uri: Uri| async move { + match make_https(host, uri, ports) { + Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), + Err(error) => { + warn!(%error, "failed to convert URI to HTTPS"); + Err(StatusCode::BAD_REQUEST) + } + } + }; + + let bind = format!("[::1]:{}", ports.http); + let listener = TcpListener::bind(&bind).await.unwrap(); + info!( + "HTTP server listening on {bind}. To contact curl http://localhost:{}", + ports.http + ); + let server = axum::serve(listener, redirect.into_make_service()).into_future(); + + select! { + biased; + + _ = signal => {}, + _ = server => {}, + } + + info!("HTTP server shutdown"); +}