diff --git a/CHANGELOG.md b/CHANGELOG.md index 43134ab..d18c1ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changes +## [1.1.0] - 2024-08-12 + +* Server graceful shutdown support + ## [1.0.0] - 2024-05-28 * Use async fn for Service::ready() and Service::shutdown() diff --git a/Cargo.toml b/Cargo.toml index ff6a225..d887182 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-h2" -version = "1.0.0" +version = "1.1.0" license = "MIT OR Apache-2.0" authors = ["Nikolay Kim "] description = "An HTTP/2 client and server" diff --git a/src/config.rs b/src/config.rs index 3b094a1..0fa27bb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,8 +8,9 @@ use crate::{consts, frame, frame::Settings, frame::WindowSize}; bitflags::bitflags! { #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct ConfigFlags: u8 { - const SERVER = 0b0000_0001; - const HTTPS = 0b0000_0010; + const SERVER = 0b0000_0001; + const HTTPS = 0b0000_0010; + const SHUTDOWN = 0b0000_0100; } } @@ -326,11 +327,30 @@ impl Config { self.0.flags.get().contains(ConfigFlags::SERVER) } + /// Check if service is shutting down. + pub fn is_shutdown(&self) -> bool { + self.0.flags.get().contains(ConfigFlags::SHUTDOWN) + } + + /// Set service shutdown. + pub fn shutdown(&self) { + let mut flags = self.0.flags.get(); + flags.insert(ConfigFlags::SHUTDOWN); + self.0.flags.set(flags); + } + pub(crate) fn inner(&self) -> &ConfigInner { self.0.as_ref() } } +impl ConfigInner { + /// Check if service is shutting down. + pub(crate) fn is_shutdown(&self) -> bool { + self.flags.get().contains(ConfigFlags::SHUTDOWN) + } +} + impl fmt::Debug for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Config") diff --git a/src/connection.rs b/src/connection.rs index 2e544b7..a3e5e99 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -427,6 +427,18 @@ impl RecvHalfConnection { "Received headers", ))) } else { + // refuse stream if connection is preparing for disconnect + if self + .0 + .flags + .get() + .contains(ConnectionFlags::DISCONNECT_WHEN_READY) + { + self.encode(frame::Reset::new(id, frame::Reason::REFUSED_STREAM)); + self.set_flags(ConnectionFlags::STREAM_REFUSED); + return Ok(None); + } + if let Some(max) = self.0.local_config.0.remote_max_concurrent_streams.get() { if self.0.active_remote_streams.get() >= max { // check if client opened more streams than allowed diff --git a/src/server/service.rs b/src/server/service.rs index 69ca304..fa52cc8 100644 --- a/src/server/service.rs +++ b/src/server/service.rs @@ -1,14 +1,12 @@ -use std::{fmt, rc::Rc}; +use std::{fmt, future::poll_fn, future::Future, pin::Pin, rc::Rc}; use ntex_io::{Dispatcher as IoDispatcher, Filter, Io, IoBoxed}; use ntex_service::{Service, ServiceCtx, ServiceFactory}; use ntex_util::time::timeout_checked; -use crate::connection::Connection; use crate::control::{Control, ControlAck}; -use crate::{ - codec::Codec, config::Config, consts, dispatcher::Dispatcher, frame, message::Message, -}; +use crate::{codec::Codec, connection::Connection}; +use crate::{config::Config, consts, dispatcher::Dispatcher, frame, message::Message}; use super::{ServerBuilder, ServerError}; @@ -139,14 +137,21 @@ where // create h2 codec let codec = Codec::default(); let con = Connection::new(io.get_ref(), codec.clone(), inner.config.clone(), true); + let con2 = con.clone(); // start protocol dispatcher - IoDispatcher::new( + let mut fut = IoDispatcher::new( io, codec, Dispatcher::new(con, ctl_srv, pub_srv), &inner.config.inner().dispatcher_config, - ) + ); + poll_fn(|cx| { + if con2.config().is_shutdown() { + con2.disconnect_when_ready(); + } + Pin::new(&mut fut).poll(cx) + }) .await .map_err(|_| ServerError::Dispatcher) } @@ -245,14 +250,22 @@ where // create h2 codec let codec = Codec::default(); let con = Connection::new(io.get_ref(), codec.clone(), config.clone(), true); + let con2 = con.clone(); // start protocol dispatcher - IoDispatcher::new( + let mut fut = IoDispatcher::new( io, codec, Dispatcher::new(con, ctl_svc, pub_svc), &config.inner().dispatcher_config, - ) + ); + + poll_fn(|cx| { + if con2.config().is_shutdown() { + con2.disconnect_when_ready(); + } + Pin::new(&mut fut).poll(cx) + }) .await .map_err(|_| ServerError::Dispatcher) }