From e9222069b0114b2d735e3f09a53817d980737380 Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Sat, 24 Jun 2023 17:50:02 +0200 Subject: [PATCH 1/5] refactor(gateway): inline `future` into `shard` Unifies the core shard logic into one module to make it easier to reason about. --- twilight-gateway/src/future.rs | 171 --------------------------------- twilight-gateway/src/lib.rs | 1 - twilight-gateway/src/shard.rs | 138 +++++++++++++++++--------- 3 files changed, 92 insertions(+), 218 deletions(-) delete mode 100644 twilight-gateway/src/future.rs diff --git a/twilight-gateway/src/future.rs b/twilight-gateway/src/future.rs deleted file mode 100644 index 29097c93d57..00000000000 --- a/twilight-gateway/src/future.rs +++ /dev/null @@ -1,171 +0,0 @@ -//! Various utility futures used by the [`Shard`]. -//! -//! These tend to be used to get around lifetime and borrow requirements, but -//! are also sometimes used to simplify logic. -//! -//! [`Shard`]: crate::Shard - -use crate::{connection::Connection, CloseFrame, CommandRatelimiter, ConnectionStatus}; -use futures_util::{future::FutureExt, stream::Next}; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{ - sync::mpsc, - task::JoinHandle, - time::{self, Duration, Interval}, -}; -use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage}; - -/// Resolved value from polling a [`NextMessageFuture`]. -/// -/// **Be sure** to keep variants in sync with documented precedence in -/// [`NextMessageFuture`]! -pub enum NextMessageFutureOutput { - /// Message has been received from the Websocket connection. - /// - /// If no message is present then the stream has ended and a new connection - /// will need to be made. - Message(Option>), - /// Heartbeat must now be sent to Discord. - SendHeartbeat, - /// Identify may now be sent to Discord. - SendIdentify, - /// Close frame has been received from the user to be relayed over the - /// Websocket connection. - UserClose(CloseFrame<'static>), - /// Message has been received from the user to be relayed over the Websocket - /// connection. - UserCommand(String), -} - -/// Future to determine the next action when [`Shard::next_message`] is called. -/// -/// Polled futures are given a consistent precedence, from first to last polled: -/// -/// - [relaying a user's close frame][1] over the Websocket connection; -/// - [sending a heartbeat to Discord][2]; -/// - [sending an identify to Discord][3]; -/// - [relaying a user's message][4] over the Websocket connection; -/// - [receiving a message][5] from Discord -/// -/// **Be sure** to keep documented precedence in sync with variants in -/// [`NextMessageFutureOutput`]! -/// -/// [1]: NextMessageFutureOutput::UserClose -/// [2]: NextMessageFutureOutput::SendHeartbeat -/// [3]: NextMessageFutureOutput::SendIdentify -/// [4]: NextMessageFutureOutput::UserCommand -/// [5]: NextMessageFutureOutput::Message -/// [`Shard::next_message`]: crate::Shard::next_message -pub struct NextMessageFuture<'a> { - /// Receiver of user sent close frames to be relayed over the Websocket - /// connection. - close_receiver: &'a mut mpsc::Receiver>, - /// Receiver of user sent commands to be relayed over the Websocket - /// connection. - command_receiver: &'a mut mpsc::UnboundedReceiver, - /// Heartbeat interval, if enadbled. - heartbeat_interval: Option<&'a mut Interval>, - /// Identify queue background task handle. - identify_handle: Option<&'a mut JoinHandle<()>>, - /// Future resolving when the next Websocket message has been received. - message_future: Next<'a, Connection>, - /// Command ratelimiter, if enabled. - ratelimiter: Option<&'a mut CommandRatelimiter>, - /// Shard's connection status. - status: &'a ConnectionStatus, -} - -impl<'a> NextMessageFuture<'a> { - /// Initialize a new series of futures determining the next action to take. - pub fn new( - close_receiver: &'a mut mpsc::Receiver>, - command_receiver: &'a mut mpsc::UnboundedReceiver, - status: &'a ConnectionStatus, - identify_handle: Option<&'a mut JoinHandle<()>>, - message_future: Next<'a, Connection>, - heartbeat_interval: Option<&'a mut Interval>, - ratelimiter: Option<&'a mut CommandRatelimiter>, - ) -> Self { - Self { - close_receiver, - command_receiver, - heartbeat_interval, - identify_handle, - message_future, - ratelimiter, - status, - } - } -} - -impl Future for NextMessageFuture<'_> { - type Output = NextMessageFutureOutput; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut(); - - if !(this.status.is_disconnected() || this.status.is_fatally_closed()) { - if let Poll::Ready(frame) = this.close_receiver.poll_recv(cx) { - return Poll::Ready(NextMessageFutureOutput::UserClose( - frame.expect("shard owns channel"), - )); - } - } - - if this - .heartbeat_interval - .as_mut() - .map_or(false, |heartbeater| heartbeater.poll_tick(cx).is_ready()) - { - return Poll::Ready(NextMessageFutureOutput::SendHeartbeat); - } - - let ratelimited = this.ratelimiter.as_mut().map_or(false, |ratelimiter| { - ratelimiter.poll_available(cx).is_pending() - }); - - // Must poll to register waker. - if !ratelimited - && this - .identify_handle - .as_mut() - .map_or(false, |handle| handle.poll_unpin(cx).is_ready()) - { - return Poll::Ready(NextMessageFutureOutput::SendIdentify); - } - - if !ratelimited && this.status.is_identified() { - if let Poll::Ready(message) = this.command_receiver.poll_recv(cx) { - return Poll::Ready(NextMessageFutureOutput::UserCommand( - message.expect("shard owns channel"), - )); - } - } - - if let Poll::Ready(maybe_try_message) = this.message_future.poll_unpin(cx) { - return Poll::Ready(NextMessageFutureOutput::Message(maybe_try_message)); - } - - Poll::Pending - } -} - -/// Future that will resolve when the delay for a reconnect passes. -/// -/// The duration of the future is defined by the number of attempts at -/// reconnecting that have already been made. The math behind it is -/// `2 ^ attempts`, maxing out at `MAX_WAIT_SECONDS`. -pub async fn reconnect_delay(reconnect_attempts: u8) { - /// The maximum wait before resolving, in seconds. - const MAX_WAIT_SECONDS: u8 = 128; - - let wait = 2_u8 - .saturating_pow(reconnect_attempts.into()) - .min(MAX_WAIT_SECONDS); - - time::sleep(Duration::from_secs(wait.into())).await; -} diff --git a/twilight-gateway/src/lib.rs b/twilight-gateway/src/lib.rs index 8594edd20fc..c64e1e87338 100644 --- a/twilight-gateway/src/lib.rs +++ b/twilight-gateway/src/lib.rs @@ -21,7 +21,6 @@ mod command; mod config; mod connection; mod event; -mod future; #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))] mod inflater; mod json; diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs index 1fde728c93f..11ac705fe13 100644 --- a/twilight-gateway/src/shard.rs +++ b/twilight-gateway/src/shard.rs @@ -63,7 +63,6 @@ use crate::{ ProcessError, ProcessErrorType, ReceiveMessageError, ReceiveMessageErrorType, SendError, SendErrorType, }, - future::{self, NextMessageFuture, NextMessageFutureOutput}, json::{self, UnknownEventError}, latency::Latency, ratelimiter::CommandRatelimiter, @@ -78,19 +77,19 @@ use serde::{de::DeserializeOwned, Deserialize}; feature = "rustls-webpki-roots" ))] use std::io::ErrorKind as IoErrorKind; -use std::{env::consts::OS, error::Error, str}; +use std::{ + env::consts::OS, + error::Error, + future::{poll_fn, Future}, + pin::Pin, + str, + task::{Context, Poll}, +}; use tokio::{ task::JoinHandle, time::{self, Duration, Instant, Interval, MissedTickBehavior}, }; -#[cfg(any( - feature = "native", - feature = "rustls-native-roots", - feature = "rustls-webpki-roots" -))] -use tokio_tungstenite::tungstenite::Error as TungsteniteError; -#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))] -use tokio_tungstenite::tungstenite::Message as TungsteniteMessage; +use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage}; use twilight_model::gateway::{ event::{Event, GatewayEventDeserializer}, payload::{ @@ -579,18 +578,79 @@ impl Shard { } let message = loop { - let future = NextMessageFuture::new( - &mut self.user_channel.close_rx, - &mut self.user_channel.command_rx, - &self.status, - self.identify_handle.as_mut(), - self.connection.as_mut().expect("connected").next(), - self.heartbeat_interval.as_mut(), - self.ratelimiter.as_mut(), - ); - - let tungstenite_message = match future.await { - NextMessageFutureOutput::Message(Some(Ok(message))) => message, + /// Actions the shard might take. + enum Action { + /// Close the gateway connection with this close frame. + Close(CloseFrame<'static>), + /// Send this command to the gateway. + Command(String), + /// Send a heartbeat command to the gateway. + Heartbeat, + /// Identify with the gateway. + Identify, + /// Handle this incoming gateway message. + Message(Option>), + } + let next_action = |cx: &mut Context<'_>| { + if !(self.status.is_disconnected() || self.status.is_fatally_closed()) { + if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) { + return Poll::Ready(Action::Close(frame.expect("shard owns channel"))); + } + } + + if self + .heartbeat_interval + .as_mut() + .map_or(false, |heartbeater| heartbeater.poll_tick(cx).is_ready()) + { + return Poll::Ready(Action::Heartbeat); + } + + let ratelimited = self.ratelimiter.as_mut().map_or(false, |ratelimiter| { + ratelimiter.poll_available(cx).is_pending() + }); + + if !ratelimited + && self + .identify_handle + .as_mut() + .map_or(false, |handle| Pin::new(handle).poll(cx).is_ready()) + { + return Poll::Ready(Action::Identify); + } + + if !ratelimited && self.status.is_identified() { + if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) { + return Poll::Ready(Action::Command(command.expect("shard owns channel"))); + } + } + + if let Poll::Ready(message) = + Pin::new(&mut self.connection.as_mut().expect("connected").next()).poll(cx) + { + return Poll::Ready(Action::Message(message)); + } + + Poll::Pending + }; + + match poll_fn(next_action).await { + Action::Message(Some(Ok(message))) => { + #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))] + if let TungsteniteMessage::Binary(bytes) = &message { + if let Some(decompressed) = self + .inflater + .inflate(bytes) + .map_err(ReceiveMessageError::from_compression)? + { + tracing::trace!(%decompressed); + break Message::Text(decompressed); + }; + } + if let Some(message) = Message::from_tungstenite(message) { + break message; + } + } // Discord, against recommendations from the WebSocket spec, // does not send a close_notify prior to shutting down the TCP // stream. This arm tries to gracefully handle this. The @@ -601,7 +661,7 @@ impl Shard { feature = "rustls-native-roots", feature = "rustls-webpki-roots" ))] - NextMessageFutureOutput::Message(Some(Err(TungsteniteError::Io(e)))) + Action::Message(Some(Err(TungsteniteError::Io(e)))) if e.kind() == IoErrorKind::UnexpectedEof // Assert we're directly connected to Discord's gateway. && self.config.proxy_url().is_none() @@ -609,7 +669,7 @@ impl Shard { { continue } - NextMessageFutureOutput::Message(Some(Err(source))) => { + Action::Message(Some(Err(source))) => { self.disconnect(CloseInitiator::None); return Err(ReceiveMessageError { @@ -617,7 +677,7 @@ impl Shard { source: Some(Box::new(source)), }); } - NextMessageFutureOutput::Message(None) => { + Action::Message(None) => { tracing::debug!("gateway connection closed"); self.connection = None; @@ -640,7 +700,7 @@ impl Shard { continue; } - NextMessageFutureOutput::SendHeartbeat => { + Action::Heartbeat => { let is_first_heartbeat = self.heartbeat_interval.is_some() && self.latency.sent().is_none(); @@ -664,7 +724,7 @@ impl Shard { continue; } - NextMessageFutureOutput::SendIdentify => { + Action::Identify => { self.identify_handle = None; tracing::debug!("sending identify"); @@ -689,7 +749,7 @@ impl Shard { continue; } - NextMessageFutureOutput::UserClose(frame) => { + Action::Close(frame) => { tracing::debug!("sending close frame from user channel"); self.session = self .close(frame) @@ -698,29 +758,14 @@ impl Shard { continue; } - NextMessageFutureOutput::UserCommand(message) => { + Action::Command(json) => { tracing::debug!("sending command from user channel"); - self.send(message) + self.send(json) .await .map_err(ReceiveMessageError::from_send)?; continue; } - }; - - #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))] - if let TungsteniteMessage::Binary(bytes) = &tungstenite_message { - if let Some(decompressed) = self - .inflater - .inflate(bytes) - .map_err(ReceiveMessageError::from_compression)? - { - tracing::trace!(%decompressed); - break Message::Text(decompressed); - }; - } - if let Some(message) = Message::from_tungstenite(tungstenite_message) { - break message; } }; @@ -1137,7 +1182,8 @@ impl Shard { close_code: Option, reconnect_attempts: u8, ) -> Result<(), ReceiveMessageError> { - future::reconnect_delay(reconnect_attempts).await; + let secs = 2u8.saturating_pow(reconnect_attempts.into()); + time::sleep(Duration::from_secs(secs.into())).await; let maybe_gateway_url = self .resume_gateway_url From e4829a5fa2139b77ff1fcf41301d283450173a00 Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Sat, 24 Jun 2023 18:00:19 +0200 Subject: [PATCH 2/5] fix(gateway): remove delay from first reconnect attempt 2^0 is 1 not 0 --- twilight-gateway/src/shard.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs index 11ac705fe13..b43cefe70fd 100644 --- a/twilight-gateway/src/shard.rs +++ b/twilight-gateway/src/shard.rs @@ -1182,8 +1182,10 @@ impl Shard { close_code: Option, reconnect_attempts: u8, ) -> Result<(), ReceiveMessageError> { - let secs = 2u8.saturating_pow(reconnect_attempts.into()); - time::sleep(Duration::from_secs(secs.into())).await; + if reconnect_attempts != 0 { + let secs = 2u8.saturating_pow(reconnect_attempts.into()); + time::sleep(Duration::from_secs(secs.into())).await; + } let maybe_gateway_url = self .resume_gateway_url From 90422cbf45c1f29d88c73284413382441bf21796 Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Sat, 24 Jun 2023 18:31:57 +0200 Subject: [PATCH 3/5] move the Action enum declaration outside the loop --- twilight-gateway/src/shard.rs | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs index b43cefe70fd..5a3c9c9d136 100644 --- a/twilight-gateway/src/shard.rs +++ b/twilight-gateway/src/shard.rs @@ -577,20 +577,21 @@ impl Shard { _ => {} } + /// Actions the shard might take. + enum Action { + /// Close the gateway connection with this close frame. + Close(CloseFrame<'static>), + /// Send this command to the gateway. + Command(String), + /// Send a heartbeat command to the gateway. + Heartbeat, + /// Identify with the gateway. + Identify, + /// Handle this incoming gateway message. + Message(Option>), + } + let message = loop { - /// Actions the shard might take. - enum Action { - /// Close the gateway connection with this close frame. - Close(CloseFrame<'static>), - /// Send this command to the gateway. - Command(String), - /// Send a heartbeat command to the gateway. - Heartbeat, - /// Identify with the gateway. - Identify, - /// Handle this incoming gateway message. - Message(Option>), - } let next_action = |cx: &mut Context<'_>| { if !(self.status.is_disconnected() || self.status.is_fatally_closed()) { if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) { From 3501ab496a07502f7eb16b397b7755105aa1f0f3 Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Sat, 24 Jun 2023 18:52:57 +0200 Subject: [PATCH 4/5] clippy --- twilight-gateway/src/shard.rs | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs index 5a3c9c9d136..04f6a122541 100644 --- a/twilight-gateway/src/shard.rs +++ b/twilight-gateway/src/shard.rs @@ -558,6 +558,20 @@ impl Shard { /// shard failed to send a message to the gateway, such as a heartbeat. #[tracing::instrument(fields(id = %self.id()), name = "shard", skip(self))] pub async fn next_message(&mut self) -> Result { + /// Actions the shard might take. + enum Action { + /// Close the gateway connection with this close frame. + Close(CloseFrame<'static>), + /// Send this command to the gateway. + Command(String), + /// Send a heartbeat command to the gateway. + Heartbeat, + /// Identify with the gateway. + Identify, + /// Handle this incoming gateway message. + Message(Option>), + } + match self.status { ConnectionStatus::Disconnected { close_code, @@ -577,20 +591,6 @@ impl Shard { _ => {} } - /// Actions the shard might take. - enum Action { - /// Close the gateway connection with this close frame. - Close(CloseFrame<'static>), - /// Send this command to the gateway. - Command(String), - /// Send a heartbeat command to the gateway. - Heartbeat, - /// Identify with the gateway. - Identify, - /// Handle this incoming gateway message. - Message(Option>), - } - let message = loop { let next_action = |cx: &mut Context<'_>| { if !(self.status.is_disconnected() || self.status.is_fatally_closed()) { From 4ec273d47aea1fa154021a88498f99360d62c7f6 Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Sat, 1 Jul 2023 17:21:14 +0200 Subject: [PATCH 5/5] fmt --- twilight-gateway/src/shard.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs index e22ba78c81f..74d3223a5cd 100644 --- a/twilight-gateway/src/shard.rs +++ b/twilight-gateway/src/shard.rs @@ -607,9 +607,10 @@ impl Shard { return Poll::Ready(Action::Heartbeat); } - let ratelimited = self.ratelimiter.as_mut().map_or(false, |ratelimiter| { - ratelimiter.poll_ready(cx).is_pending() - }); + let ratelimited = self + .ratelimiter + .as_mut() + .map_or(false, |ratelimiter| ratelimiter.poll_ready(cx).is_pending()); if !ratelimited && self