diff --git a/twilight-gateway/Cargo.toml b/twilight-gateway/Cargo.toml index beeba781a63..55a4b8150f6 100644 --- a/twilight-gateway/Cargo.toml +++ b/twilight-gateway/Cargo.toml @@ -15,12 +15,12 @@ version = "0.15.4" [dependencies] bitflags = { default-features = false, version = "2" } -futures-util = { default-features = false, features = ["std"], version = "0.3" } -rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" } +fastrand = { default-features = false, features = ["std"], version = "2" } +futures-util = { default-features = false, features = ["sink", "std"], version = "0.3" } serde = { default-features = false, features = ["derive"], version = "1" } serde_json = { default-features = false, features = ["std"], version = "1" } tokio = { default-features = false, features = ["net", "rt", "sync", "time"], version = "1.19" } -tokio-tungstenite = { default-features = false, features = ["connect"], version = "0.19" } +tokio-websockets = { default-features = false, features = ["client", "fastrand", "sha1_smol", "simd"], version = "0.4" } tracing = { default-features = false, features = ["std", "attributes"], version = "0.1" } twilight-gateway-queue = { default-features = false, path = "../twilight-gateway-queue", version = "0.15.4" } twilight-model = { default-features = false, path = "../twilight-model", version = "0.15.4" } @@ -34,13 +34,6 @@ flate2 = { default-features = false, optional = true, version = "1.0.24" } twilight-http = { default-features = false, optional = true, path = "../twilight-http", version = "0.15.4" } simd-json = { default-features = false, features = ["serde_impl", "swar-number-parsing"], optional = true, version = ">=0.4, <0.11" } -# TLS libraries -# They are needed to track what is used in tokio-tungstenite -native-tls = { default-features = false, optional = true, version = "0.2.8" } -rustls-native-certs = { default-features = false, optional = true, version = "0.6" } -rustls-tls = { default-features = false, optional = true, package = "rustls", version = "0.21" } -webpki-roots = { default-features = false, optional = true, version = "0.23" } - [dev-dependencies] anyhow = { default-features = false, features = ["std"], version = "1" } futures = { default-features = false, version = "0.3" } @@ -51,9 +44,9 @@ tracing-subscriber = { default-features = false, features = ["fmt", "tracing-log [features] default = ["rustls-native-roots", "twilight-http", "zlib-stock"] -native = ["dep:native-tls", "tokio-tungstenite/native-tls"] -rustls-native-roots = ["dep:rustls-tls", "dep:rustls-native-certs", "tokio-tungstenite/rustls-tls-native-roots"] -rustls-webpki-roots = ["dep:rustls-tls", "dep:webpki-roots", "tokio-tungstenite/rustls-tls-webpki-roots"] +native = ["tokio-websockets/native-tls", "tokio-websockets/openssl"] +rustls-native-roots = ["tokio-websockets/rustls-native-roots"] +rustls-webpki-roots = ["tokio-websockets/rustls-webpki-roots"] zlib-simd = ["dep:flate2", "flate2?/zlib-ng"] zlib-stock = ["dep:flate2", "flate2?/zlib"] diff --git a/twilight-gateway/src/config.rs b/twilight-gateway/src/config.rs index 8c14e6bba40..746cea144f7 100644 --- a/twilight-gateway/src/config.rs +++ b/twilight-gateway/src/config.rs @@ -2,13 +2,13 @@ use crate::{ queue::{InMemoryQueue, Queue}, - tls::TlsContainer, EventTypeFlags, Session, }; use std::{ fmt::{Debug, Formatter, Result as FmtResult}, sync::Arc, }; +use tokio_websockets::Connector; use twilight_model::gateway::{ payload::outgoing::{identify::IdentifyProperties, update_presence::UpdatePresencePayload}, Intents, @@ -69,7 +69,7 @@ pub struct Config { /// TLS connector for Websocket connections. // We need this to be public so [`stream`] can re-use TLS on multiple shards // if unconfigured. - tls: TlsContainer, + tls: Arc, /// Token used to authenticate when identifying with the gateway. /// /// The token is prefixed with "Bot ", which is required by Discord for @@ -147,7 +147,7 @@ impl Config { } /// Immutable reference to the TLS connector in use by the shard. - pub(crate) const fn tls(&self) -> &TlsContainer { + pub(crate) fn tls(&self) -> &Connector { &self.tls } @@ -195,7 +195,7 @@ impl ConfigBuilder { queue: Arc::new(InMemoryQueue::default()), ratelimit_messages: true, session: None, - tls: TlsContainer::new().unwrap(), + tls: Arc::new(Connector::new().unwrap()), token: Token::new(token.into_boxed_str()), }, } diff --git a/twilight-gateway/src/connection.rs b/twilight-gateway/src/connection.rs index 1061d735f66..aba90a10d6a 100644 --- a/twilight-gateway/src/connection.rs +++ b/twilight-gateway/src/connection.rs @@ -1,9 +1,12 @@ //! Utilities for creating Websocket connections. -use crate::{error::ReceiveMessageError, tls::TlsContainer, API_VERSION}; +use crate::{ + error::{ReceiveMessageError, ReceiveMessageErrorType}, + API_VERSION, +}; use std::fmt::{Display, Formatter, Result as FmtResult}; use tokio::net::TcpStream; -use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, MaybeTlsStream, WebSocketStream}; +use tokio_websockets::{ClientBuilder, Connector, Limits, MaybeTlsStream, WebsocketStream}; /// Query argument with zlib-stream enabled. #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))] @@ -16,29 +19,12 @@ const COMPRESSION_FEATURES: &str = ""; /// URL of the Discord gateway. const GATEWAY_URL: &str = "wss://gateway.discord.gg"; -/// Configuration used for Websocket connections. -/// -/// `max_frame_size` and `max_message_queue` limits are disabled because -/// Discord is not a malicious actor and having a limit has caused problems on -/// large [`GuildCreate`] payloads. -/// -/// `accept_unmasked_frames` and `max_send_queue` are set to their -/// defaults. -/// -/// [`GuildCreate`]: twilight_model::gateway::payload::incoming::GuildCreate -const WEBSOCKET_CONFIG: WebSocketConfig = WebSocketConfig { - accept_unmasked_frames: false, - max_frame_size: None, - max_message_size: None, - max_send_queue: None, -}; - -/// [`tokio_tungstenite`] library Websocket connection. +/// [`tokio_websockets`] library Websocket connection. /// /// Connections are used by [`Shard`]s when reconnecting. /// /// [`Shard`]: crate::Shard -pub type Connection = WebSocketStream>; +pub type Connection = WebsocketStream>; /// Formatter for a gateway URL, with the API version and compression features /// specified. @@ -93,12 +79,30 @@ impl Display for ConnectionUrl<'_> { #[tracing::instrument(skip_all)] pub async fn connect( maybe_gateway_url: Option<&str>, - tls: &TlsContainer, + tls: &Connector, ) -> Result { let url = ConnectionUrl::new(maybe_gateway_url).to_string(); + // Limits to impose on Websocket connections. + // + // `max_payload_len` limit is disabled because Discord is not a malicious + // actor and having a limit has caused problems on large `GuildCreate` + // payloads. + let limits = Limits::default().max_payload_len(None); + tracing::debug!(?url, "shaking hands with gateway"); - let stream = tls.connect(&url, WEBSOCKET_CONFIG).await?; + + let (stream, _) = ClientBuilder::new() + .uri(&url) + .expect("Gateway URL must be valid") + .limits(limits) + .connector(tls) + .connect() + .await + .map_err(|source| ReceiveMessageError { + kind: ReceiveMessageErrorType::Reconnect, + source: Some(Box::new(source)), + })?; Ok(stream) } diff --git a/twilight-gateway/src/lib.rs b/twilight-gateway/src/lib.rs index c64e1e87338..bbce47034b0 100644 --- a/twilight-gateway/src/lib.rs +++ b/twilight-gateway/src/lib.rs @@ -29,7 +29,6 @@ mod message; mod ratelimiter; mod session; mod shard; -mod tls; #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))] pub use self::inflater::Inflater; diff --git a/twilight-gateway/src/message.rs b/twilight-gateway/src/message.rs index 29e9c2b11f2..9987a295e77 100644 --- a/twilight-gateway/src/message.rs +++ b/twilight-gateway/src/message.rs @@ -6,10 +6,9 @@ //! input will not be checked and will be passed directly to the underlying //! websocket library. -use tokio_tungstenite::tungstenite::{ - protocol::{frame::coding::CloseCode, CloseFrame as TungsteniteCloseFrame}, - Message as TungsteniteMessage, -}; +use std::borrow::Cow; + +use tokio_websockets::{CloseCode, Message as WebsocketMessage}; use twilight_model::gateway::CloseFrame; /// Message to send over the connection to the remote. @@ -25,33 +24,36 @@ pub enum Message { } impl Message { - /// Convert a `tungstenite` websocket message into a `twilight` websocket + /// Convert a `tokio-websockets` websocket message into a `twilight` websocket /// message. - pub(crate) fn from_tungstenite(tungstenite: TungsteniteMessage) -> Option { - match tungstenite { - TungsteniteMessage::Close(frame) => Some(Self::Close(frame.map(|frame| CloseFrame { - code: frame.code.into(), - reason: frame.reason, - }))), - TungsteniteMessage::Text(string) => Some(Self::Text(string)), - TungsteniteMessage::Binary(_) - | TungsteniteMessage::Frame(_) - | TungsteniteMessage::Ping(_) - | TungsteniteMessage::Pong(_) => None, + pub(crate) fn from_websocket_msg(msg: &WebsocketMessage) -> Option { + if msg.is_close() { + let (code, reason) = msg.as_close().unwrap(); + + let frame = (code == CloseCode::NO_STATUS_RECEIVED).then(|| CloseFrame { + code: code.into(), + reason: Cow::Owned(reason.to_string()), + }); + + Some(Self::Close(frame)) + } else if msg.is_text() { + Some(Self::Text(msg.as_text().unwrap().to_owned())) + } else { + None } } - /// Convert a `twilight` websocket message into a `tungstenite` websocket + /// Convert a `twilight` websocket message into a `tokio-websockets` websocket /// message. - pub(crate) fn into_tungstenite(self) -> TungsteniteMessage { + pub(crate) fn into_websocket_msg(self) -> WebsocketMessage { match self { - Self::Close(frame) => { - TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame { - code: CloseCode::from(frame.code), - reason: frame.reason, - })) - } - Self::Text(string) => TungsteniteMessage::Text(string), + Self::Close(frame) => WebsocketMessage::close( + frame + .as_ref() + .and_then(|f| CloseCode::try_from(f.code).ok()), + frame.map(|f| f.reason).as_deref().unwrap_or_default(), + ), + Self::Text(string) => WebsocketMessage::text(string), } } } diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs index 6dfc5f143a6..3dc707eb4b1 100644 --- a/twilight-gateway/src/shard.rs +++ b/twilight-gateway/src/shard.rs @@ -89,7 +89,7 @@ use tokio::{ sync::oneshot, time::{self, Duration, Instant, Interval, MissedTickBehavior}, }; -use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage}; +use tokio_websockets::{Error as WebsocketError, Message as WebsocketMessage}; use twilight_model::gateway::{ event::{Event, GatewayEventDeserializer}, payload::{ @@ -568,7 +568,7 @@ impl Shard { /// Identify with the gateway. Identify, /// Handle this incoming gateway message. - Message(Option>), + Message(Option>), } match self.status { @@ -672,17 +672,17 @@ impl Shard { match poll_fn(next_action).await { Action::Message(Some(Ok(message))) => { #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))] - if let TungsteniteMessage::Binary(bytes) = &message { + if message.is_binary() { if let Some(decompressed) = self .inflater - .inflate(bytes) + .inflate(message.as_payload()) .map_err(ReceiveMessageError::from_compression)? { tracing::trace!(%decompressed); break Message::Text(decompressed); }; } - if let Some(message) = Message::from_tungstenite(message) { + if let Some(message) = Message::from_websocket_msg(&message) { break message; } } @@ -696,7 +696,7 @@ impl Shard { feature = "rustls-native-roots", feature = "rustls-webpki-roots" ))] - Action::Message(Some(Err(TungsteniteError::Io(e)))) + Action::Message(Some(Err(WebsocketError::Io(e)))) if e.kind() == IoErrorKind::UnexpectedEof // Assert we're directly connected to Discord's gateway. && self.config.proxy_url().is_none() @@ -804,13 +804,11 @@ impl Shard { match &message { Message::Close(frame) => { - // Tungstenite automatically replies to the close message. + // tokio-websockets automatically replies to the close message. tracing::debug!(?frame, "received websocket close message"); // Don't run `disconnect` if we initiated the close. if !self.status.is_disconnected() { - self.disconnect(CloseInitiator::Gateway( - frame.as_ref().map(|frame| frame.code), - )); + self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code))); } } Message::Text(event) => { @@ -913,7 +911,7 @@ impl Shard { kind: SendErrorType::Sending, source: None, })? - .send(message.into_tungstenite()) + .send(message.into_websocket_msg()) .await .map_err(|source| SendError { kind: SendErrorType::Sending, @@ -1116,7 +1114,7 @@ impl Shard { let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval); // First heartbeat should have some jitter, see // https://discord.com/developers/docs/topics/gateway#heartbeat-interval - let jitter = heartbeat_interval.mul_f64(rand::random()); + let jitter = heartbeat_interval.mul_f64(fastrand::f64()); tracing::debug!(?heartbeat_interval, ?jitter, "received hello"); if self.config().ratelimit_messages() { diff --git a/twilight-gateway/src/tls.rs b/twilight-gateway/src/tls.rs deleted file mode 100644 index f100a22a09c..00000000000 --- a/twilight-gateway/src/tls.rs +++ /dev/null @@ -1,332 +0,0 @@ -//! TLS manager to reuse connections between shards. - -#[cfg(not(any( - feature = "native", - feature = "rustls-native-roots", - feature = "rustls-webpki-roots" -)))] -mod r#impl { - //! Plain connections with no TLS. - - /// No connector is used when plain connections are enabled. - pub type TlsConnector = (); - - use super::{TlsContainer, TlsError}; - use crate::{ - connection::Connection, - error::{ReceiveMessageError, ReceiveMessageErrorType}, - }; - use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector}; - - /// Create a TLS container without a TLS connector. - /// - /// # Errors - /// - /// Never returns an error, and only returns a Result to reach parity when - /// TLS features are enabled. - pub fn new() -> Result { - Ok(TlsContainer { tls: None }) - } - - /// Connect to the provided URL without TLS. - pub async fn connect( - url: &str, - config: WebSocketConfig, - _tls: &TlsContainer, - ) -> Result { - let (stream, _) = tokio_tungstenite::connect_async_with_config(url, Some(config), false) - .await - .map_err(|source| ReceiveMessageError { - kind: ReceiveMessageErrorType::Reconnect, - source: Some(Box::new(source)), - })?; - - Ok(stream) - } - - /// No TLS connector. - pub fn connector(_: &TlsContainer) -> Option { - None - } -} - -#[cfg(all( - feature = "native", - not(any(feature = "rustls-native-roots", feature = "rustls-webpki-roots")) -))] -mod r#impl { - //! Native TLS - - pub use native_tls::TlsConnector; - - use super::{TlsContainer, TlsError, TlsErrorType}; - use crate::{ - connection::Connection, - error::{ReceiveMessageError, ReceiveMessageErrorType}, - }; - use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector}; - - /// Create a new TLS connector. - /// - /// # Errors - /// - /// Returns a [`TlsErrorType::Loading`] error type if the TLS connector - /// couldn't be initialized. - pub fn new() -> Result { - let native_connector = TlsConnector::new().map_err(|err| TlsError { - kind: TlsErrorType::Loading, - source: Some(Box::new(err)), - })?; - - Ok(TlsContainer { - tls: Some(native_connector), - }) - } - - /// Connect to the provided URL with the underlying TLS connector. - pub async fn connect( - url: &str, - config: WebSocketConfig, - tls: &TlsContainer, - ) -> Result { - let (stream, _) = tokio_tungstenite::connect_async_tls_with_config( - url, - Some(config), - false, - tls.connector(), - ) - .await - .map_err(|source| ReceiveMessageError { - kind: ReceiveMessageErrorType::Reconnect, - source: Some(Box::new(source)), - })?; - - Ok(stream) - } - - /// Clone the underlying TLS connector for native TLS. - pub fn connector(container: &TlsContainer) -> Option { - container - .tls - .as_ref() - .map(|tls| Connector::NativeTls(tls.clone())) - } -} - -#[cfg(any(feature = "rustls-native-roots", feature = "rustls-webpki-roots"))] -mod r#impl { - //! Rustls - - use super::{TlsContainer, TlsError}; - use crate::{ - connection::Connection, - error::{ReceiveMessageError, ReceiveMessageErrorType}, - }; - use rustls_tls::ClientConfig; - use std::sync::Arc; - use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector}; - - /// Rustls client configuration. - pub type TlsConnector = Arc; - - /// Create a new TLS connector. - /// - /// # Errors - /// - /// Returns a `TlsErrorType::Loading` error type if the TLS connector - /// couldn't be initialized. - #[cfg(any(feature = "rustls-native-roots", feature = "rustls-webpki-roots"))] - pub fn new() -> Result { - let mut roots = rustls_tls::RootCertStore::empty(); - - #[cfg(feature = "rustls-native-roots")] - { - let certs = rustls_native_certs::load_native_certs().map_err(|err| TlsError { - kind: super::TlsErrorType::Loading, - source: Some(Box::new(err)), - })?; - - for cert in certs { - roots - .add(&rustls_tls::Certificate(cert.0)) - .map_err(|err| TlsError { - kind: super::TlsErrorType::Loading, - source: Some(Box::new(err)), - })?; - } - } - - #[cfg(feature = "rustls-webpki-roots")] - { - roots.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - rustls_tls::OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - }; - - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(roots) - .with_no_client_auth(); - - Ok(TlsContainer { - tls: Some(Arc::new(config)), - }) - } - - /// Connect to the provided URL with the underlying TLS connector. - pub async fn connect( - url: &str, - config: WebSocketConfig, - tls: &TlsContainer, - ) -> Result { - let (stream, _) = tokio_tungstenite::connect_async_tls_with_config( - url, - Some(config), - false, - tls.connector(), - ) - .await - .map_err(|source| ReceiveMessageError { - kind: ReceiveMessageErrorType::Reconnect, - source: Some(Box::new(source)), - })?; - - Ok(stream) - } - - /// Clone the underlying TLS connector for rustls. - pub fn connector(container: &TlsContainer) -> Option { - container - .tls - .as_ref() - .map(|tls| Connector::Rustls(Arc::clone(tls))) - } -} - -use r#impl::TlsConnector; -use std::{ - error::Error, - fmt::{Debug, Display, Formatter, Result as FmtResult}, -}; -use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector}; - -use crate::{connection::Connection, error::ReceiveMessageError}; - -/// Creating a TLS connector failed, possibly due to loading certificates. -#[derive(Debug)] -pub struct TlsError { - /// Type of error. - kind: TlsErrorType, - /// Source error if available. - source: Option>, -} - -#[allow(dead_code)] -impl TlsError { - /// Immutable reference to the type of error that occurred. - #[must_use = "retrieving the type has no effect if left unused"] - pub const fn kind(&self) -> &TlsErrorType { - &self.kind - } - - /// Consume the error, returning the source error if there is any. - #[must_use = "consuming the error and retrieving the source has no effect if left unused"] - pub fn into_source(self) -> Option> { - self.source - } - - /// Consume the error, returning the owned error type and the source error. - #[must_use = "consuming the error into its parts has no effect if left unused"] - pub fn into_parts(self) -> (TlsErrorType, Option>) { - (self.kind, self.source) - } -} - -impl Display for TlsError { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self.kind { - TlsErrorType::Loading => { - f.write_str("failed to load the tls connector or its certificates") - } - } - } -} - -impl Error for TlsError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.source - .as_ref() - .map(|source| &**source as &(dyn Error + 'static)) - } -} - -/// Type of [`TlsError`] that occurred. -#[derive(Debug)] -#[non_exhaustive] -pub enum TlsErrorType { - /// Loading the TLS connector or its certificates failed. - #[allow(unused)] - Loading, -} - -/// Wrapper over a native or Rustls TLS connector. -#[derive(Clone)] -pub struct TlsContainer { - /// TLS connector, which won't be present if no TLS feature is enabled. - #[allow(unused)] - tls: Option, -} - -impl Debug for TlsContainer { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - let mut debugger = f.debug_struct("TlsContainer"); - - #[cfg(all( - feature = "native", - not(any(feature = "rustls-native-roots", feature = "rustls-webpki-roots")), - ))] - debugger.field("tls", &self.tls); - - debugger.finish() - } -} - -impl TlsContainer { - /// Create a new TLS connector. - /// - /// # Errors - /// - /// For non-plain TLS, returns a [`TlsErrorType::Loading`] error type if - /// the TLS connector couldn't be initialized. - pub fn new() -> Result { - r#impl::new() - } - - /// Connect to the provided URL with the underlying TLS connector. - pub async fn connect( - &self, - url: &str, - config: WebSocketConfig, - ) -> Result { - r#impl::connect(url, config, self).await - } - - /// Clone of a reference to the connector. - #[allow(unused)] - pub(crate) fn connector(&self) -> Option { - r#impl::connector(self) - } -} - -#[cfg(test)] -mod tests { - use super::TlsContainer; - use static_assertions::assert_impl_all; - use std::fmt::Debug; - - assert_impl_all!(TlsContainer: Debug, Clone, Send, Sync); -} diff --git a/twilight-http/Cargo.toml b/twilight-http/Cargo.toml index 342f555cfe1..c31ced68e5b 100644 --- a/twilight-http/Cargo.toml +++ b/twilight-http/Cargo.toml @@ -14,12 +14,12 @@ rust-version.workspace = true version = "0.15.4" [dependencies] +fastrand = { default-features = false, features = ["std"], version = "2" } hyper = { default-features = false, features = ["client", "http1", "http2", "runtime"], version = "0.14" } hyper-rustls = { default-features = false, optional = true, features = ["http1", "http2"], version = "0.24" } hyper-tls = { default-features = false, optional = true, version = "0.5" } hyper-trust-dns = { default-features = false, optional = true, version = "0.5" } percent-encoding = { default-features = false, version = "2" } -rand = { default-features = false, features = ["std_rng", "std"], version = "0.8" } serde = { default-features = false, features = ["derive"], version = "1" } serde_json = { default-features = false, features = ["std"], version = "1" } tokio = { default-features = false, features = ["sync", "time"], version = "1.0" } diff --git a/twilight-http/src/request/multipart.rs b/twilight-http/src/request/multipart.rs index b1ee52c8df7..3bd2c7635ec 100644 --- a/twilight-http/src/request/multipart.rs +++ b/twilight-http/src/request/multipart.rs @@ -1,5 +1,3 @@ -use rand::{distributions::Alphanumeric, Rng}; - #[derive(Clone, Debug)] #[must_use = "has no effect if not built into a Form"] pub struct Form { @@ -128,10 +126,9 @@ impl Default for Form { /// Generate a random boundary that is 15 characters long. pub fn random_boundary() -> [u8; 15] { let mut boundary = [0; 15]; - let mut rng = rand::thread_rng(); for value in &mut boundary { - *value = rng.sample(Alphanumeric); + *value = fastrand::alphanumeric() as u8; } boundary diff --git a/twilight-lavalink/Cargo.toml b/twilight-lavalink/Cargo.toml index e32e3afeffe..ef697a1b786 100644 --- a/twilight-lavalink/Cargo.toml +++ b/twilight-lavalink/Cargo.toml @@ -15,12 +15,12 @@ version = "0.15.3" [dependencies] dashmap = { default-features = false, version = "5.3" } -futures-util = { default-features = false, features = ["bilock", "std", "unstable"], version = "0.3" } +futures-util = { default-features = false, features = ["bilock", "sink", "std", "unstable"], version = "0.3" } http = { default-features = false, version = "0.2" } serde = { default-features = false, features = ["derive", "std"], version = "1" } serde_json = { default-features = false, features = ["std"], version = "1" } tokio = { default-features = false, features = ["macros", "net", "rt", "sync", "time"], version = "1.0" } -tokio-tungstenite = { default-features = false, features = ["connect"], version = "0.19" } +tokio-websockets = { default-features = false, features = ["client", "fastrand", "sha1_smol", "simd"], version = "0.4" } tracing = { default-features = false, features = ["std", "attributes"], version = "0.1" } twilight-model = { default-features = false, path = "../twilight-model", version = "0.15.4" } @@ -39,9 +39,9 @@ twilight-http = { default-features = false, features = ["rustls-native-roots"], [features] default = ["http-support", "rustls-native-roots"] http-support = ["dep:percent-encoding"] -native = ["tokio-tungstenite/native-tls"] -rustls-native-roots = ["tokio-tungstenite/rustls-tls-native-roots"] -rustls-webpki-roots = ["tokio-tungstenite/rustls-tls-webpki-roots"] +native = ["tokio-websockets/native-tls", "tokio-websockets/openssl"] +rustls-native-roots = ["tokio-websockets/rustls-native-roots"] +rustls-webpki-roots = ["tokio-websockets/rustls-webpki-roots"] [package.metadata.docs.rs] all-features = true diff --git a/twilight-lavalink/README.md b/twilight-lavalink/README.md index e90a6ace989..0c56c48a6a0 100644 --- a/twilight-lavalink/README.md +++ b/twilight-lavalink/README.md @@ -22,13 +22,13 @@ request types from the [`http`] crate. This is enabled by default. ### TLS -`twilight-lavalink` has features to enable [`tokio-tungstenite`]'s TLS +`twilight-lavalink` has features to enable [`tokio-websockets`]'s TLS features. These features are mutually exclusive. `rustls-native-roots` is enabled by default. #### `native` -The `native` feature enables [`tokio-tungstenite`]'s `native-tls` feature. +The `native` feature enables [`tokio-websockets`]'s `native-tls` feature. To enable `native`, do something like this in your `Cargo.toml`: @@ -39,14 +39,14 @@ twilight-lavalink = { default-features = false, features = ["native"], version = #### `rustls-native-roots` -The `rustls-native-roots` feature enables [`tokio-tungstenite`]'s `rustls-tls-native-roots` feature, +The `rustls-native-roots` feature enables [`tokio-websockets`]'s `rustls-native-roots` feature, which uses [`rustls`] as the TLS backend and [`rustls-native-certs`] for root certificates. This is enabled by default. #### `rustls-webpki-roots` -The `rustls-webpki-roots` feature enables [`tokio-tungstenite`]'s `rustls-tls-webpki-roots` feature, +The `rustls-webpki-roots` feature enables [`tokio-websockets`]'s `rustls-webpki-roots` feature, which uses [`rustls`] as the TLS backend and [`webpki-roots`] for root certificates. This should be preferred over `rustls-native-roots` in Docker containers based on `scratch`. @@ -115,7 +115,7 @@ There is also an example of a basic bot located in the [root of the [`http`]: https://crates.io/crates/http [`rustls`]: https://crates.io/crates/rustls [`rustls-native-certs`]: https://crates.io/crates/rustls-native-certs -[`tokio-tungstenite`]: https://crates.io/crates/tokio-tungstenite +[`tokio-websockets`]: https://crates.io/crates/tokio-websockets [`webpki-roots`]: https://crates.io/crates/webpki-roots [client]: Lavalink [codecov badge]: https://img.shields.io/codecov/c/gh/twilight-rs/twilight?logo=codecov&style=for-the-badge&token=E9ERLJL0L2 diff --git a/twilight-lavalink/src/node.rs b/twilight-lavalink/src/node.rs index 2ee3c228a08..96434fb5e69 100644 --- a/twilight-lavalink/src/node.rs +++ b/twilight-lavalink/src/node.rs @@ -26,7 +26,7 @@ use futures_util::{ sink::SinkExt, stream::{Stream, StreamExt}, }; -use http::{header::HeaderName, Request, Response, StatusCode}; +use http::header::{HeaderName, AUTHORIZATION}; use std::{ error::Error, fmt::{Debug, Display, Formatter, Result as FmtResult}, @@ -40,9 +40,8 @@ use tokio::{ sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, time as tokio_time, }; -use tokio_tungstenite::{ - tungstenite::{client::IntoClientRequest, Error as TungsteniteError, Message}, - MaybeTlsStream, WebSocketStream, +use tokio_websockets::{ + upgrade, ClientBuilder, Error as WebsocketError, MaybeTlsStream, Message, WebsocketStream, }; use twilight_model::id::{marker::UserMarker, Id}; @@ -465,7 +464,7 @@ impl Node { struct Connection { config: NodeConfig, - connection: WebSocketStream>, + connection: WebsocketStream>, node_from: UnboundedReceiver, node_to: UnboundedSender, players: PlayerManager, @@ -526,7 +525,7 @@ impl Connection { kind: NodeErrorType::SerializingMessage { message: outgoing }, source: Some(Box::new(source)), })?; - let msg = Message::Text(payload); + let msg = Message::text(payload); self.connection.send(msg).await.unwrap(); } else { tracing::debug!("node {} closed, ending connection", self.config.address); @@ -546,31 +545,19 @@ impl Connection { self.config.address, ); - let text = match incoming { - Message::Close(_) => { - tracing::debug!("got close, closing connection"); - let _result = self.connection.send(Message::Close(None)).await; + let text = if incoming.is_text() { + incoming.as_text().expect("message is text") + } else if incoming.is_close() { + tracing::debug!("got close, closing connection"); - return Ok(false); - } - Message::Ping(data) => { - tracing::debug!("got ping, sending pong"); - let msg = Message::Pong(data); - - // We don't need to immediately care if a pong fails. - let _result = self.connection.send(msg).await; - - return Ok(true); - } - Message::Text(text) => text, - other => { - tracing::debug!("got pong or bytes payload: {other:?}"); + return Ok(false); + } else { + tracing::debug!("got ping, pong or binary payload: {incoming:?}"); - return Ok(true); - } + return Ok(true); }; - let Ok(event) = serde_json::from_str(&text) else { + let Ok(event) = serde_json::from_str(text) else { tracing::warn!("unknown message from lavalink node: {text}"); return Ok(true); @@ -623,27 +610,32 @@ impl Drop for Connection { } } -fn connect_request(state: &NodeConfig) -> Result, NodeError> { - let mut request = format!("ws://{}", state.address) - .into_client_request() +fn connect_request(state: &NodeConfig) -> Result { + let mut builder = ClientBuilder::new() + .uri(&format!("ws://{}", state.address)) .map_err(|source| NodeError { kind: NodeErrorType::BuildingConnectionRequest, source: Some(Box::new(source)), - })?; - let headers = request.headers_mut(); - headers.insert("Authorization", state.authorization.parse().unwrap()); - headers.insert("User-Id", state.user_id.get().into()); + })? + .add_header(AUTHORIZATION, state.authorization.parse().unwrap()) + .add_header( + HeaderName::from_static("User-Id"), + state.user_id.get().into(), + ); if state.resume.is_some() { - headers.insert("Resume-Key", state.address.to_string().parse().unwrap()); + builder = builder.add_header( + HeaderName::from_static("Resume-Key"), + state.address.to_string().parse().unwrap(), + ); } - Ok(request) + Ok(builder) } async fn reconnect( config: &NodeConfig, -) -> Result>, NodeError> { +) -> Result>, NodeError> { let (mut stream, res) = backoff(config).await?; let headers = res.headers(); @@ -660,7 +652,7 @@ async fn reconnect( "key": config.address, "timeout": resume.timeout, }); - let msg = Message::Text(serde_json::to_string(&payload).unwrap()); + let msg = Message::text(serde_json::to_string(&payload).unwrap()); stream.send(msg).await.unwrap(); } else { @@ -676,8 +668,8 @@ async fn backoff( config: &NodeConfig, ) -> Result< ( - WebSocketStream>, - Response>>, + WebsocketStream>, + upgrade::Response, ), NodeError, > { @@ -686,12 +678,12 @@ async fn backoff( loop { let request = connect_request(config)?; - match tokio_tungstenite::connect_async(request).await { + match request.connect().await { Ok((stream, response)) => return Ok((stream, response)), Err(source) => { tracing::warn!("failed to connect to node {source}: {:?}", config.address); - if matches!(&source, TungsteniteError::Http(resp) if resp.status() == StatusCode::UNAUTHORIZED) + if matches!(&source, WebsocketError::Upgrade(upgrade::Error::DidNotSwitchProtocols(status)) if status == &403) { return Err(NodeError { kind: NodeErrorType::Unauthorized {