Skip to content

Commit

Permalink
refactor(gateway, lavalink): Switch to tokio-websockets
Browse files Browse the repository at this point in the history
Signed-off-by: Jens Reidel <[email protected]>
  • Loading branch information
Gelbpunkt committed Jul 8, 2023
1 parent abbf91c commit a596c4d
Show file tree
Hide file tree
Showing 13 changed files with 128 additions and 476 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ include = ["src/**/*.rs", "README.md"]
license = "ISC"
repository = "https://github.com/twilight-rs/twilight.git"
rust-version = "1.67"

[patch.crates-io]
tokio-websockets = { git = "https://github.com/Gelbpunkt/tokio-websockets.git" }
19 changes: 7 additions & 12 deletions twilight-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ version = "0.15.2"

[dependencies]
bitflags = { default-features = false, version = "2" }
fastrand = { default-features = false, features = ["std"], version = "2" }
futures-util = { default-features = false, features = ["std"], version = "0.3" }
rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" }
serde = { default-features = false, features = ["derive"], version = "1" }
serde_json = { default-features = false, features = ["std"], version = "1" }
tokio = { default-features = false, features = ["net", "rt", "sync", "time"], version = "1.8" }
tokio-tungstenite = { default-features = false, features = ["connect"], version = "0.19" }
tokio-websockets = { default-features = false, features = ["client", "fastrand"], version = "0.3" }
tracing = { default-features = false, features = ["std", "attributes"], version = "0.1" }
twilight-gateway-queue = { default-features = false, path = "../twilight-gateway-queue", version = "0.15.2" }
twilight-model = { default-features = false, path = "../twilight-model", version = "0.15.2" }
Expand All @@ -34,13 +34,6 @@ flate2 = { default-features = false, optional = true, version = "1.0.24" }
twilight-http = { default-features = false, optional = true, path = "../twilight-http", version = "0.15.2" }
simd-json = { default-features = false, features = ["serde_impl", "swar-number-parsing"], optional = true, version = ">=0.4, <0.11" }

# TLS libraries
# They are needed to track what is used in tokio-tungstenite
native-tls = { default-features = false, optional = true, version = "0.2.8" }
rustls-native-certs = { default-features = false, optional = true, version = "0.6" }
rustls-tls = { default-features = false, optional = true, package = "rustls", version = "0.21" }
webpki-roots = { default-features = false, optional = true, version = "0.23" }

[dev-dependencies]
anyhow = { default-features = false, features = ["std"], version = "1" }
futures = { default-features = false, version = "0.3" }
Expand All @@ -51,11 +44,13 @@ tracing-subscriber = { default-features = false, features = ["fmt", "tracing-log

[features]
default = ["rustls-native-roots", "twilight-http", "zlib-stock"]
native = ["dep:native-tls", "tokio-tungstenite/native-tls"]
rustls-native-roots = ["dep:rustls-tls", "dep:rustls-native-certs", "tokio-tungstenite/rustls-tls-native-roots"]
rustls-webpki-roots = ["dep:rustls-tls", "dep:webpki-roots", "tokio-tungstenite/rustls-tls-webpki-roots"]
no-tls = ["tokio-websockets/sha1_smol"]
native = ["tokio-websockets/native-tls", "tokio-websockets/openssl"]
rustls-native-roots = ["tokio-websockets/rustls-native-roots"]
rustls-webpki-roots = ["tokio-websockets/rustls-webpki-roots"]
zlib-simd = ["dep:flate2", "flate2?/zlib-ng"]
zlib-stock = ["dep:flate2", "flate2?/zlib"]
websockets-simd = ["tokio-websockets/simd"]

[package.metadata.docs.rs]
rustdoc-args = ["--cfg", "docsrs"]
9 changes: 5 additions & 4 deletions twilight-gateway/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! User configuration for shards.

use crate::{tls::TlsContainer, EventTypeFlags, Session};
use crate::{EventTypeFlags, Session};
use std::{
fmt::{Debug, Formatter, Result as FmtResult},
sync::Arc,
};
use tokio_websockets::Connector;
use twilight_gateway_queue::{LocalQueue, Queue};
use twilight_model::gateway::{
payload::outgoing::{identify::IdentifyProperties, update_presence::UpdatePresencePayload},
Expand Down Expand Up @@ -67,7 +68,7 @@ pub struct Config {
/// TLS connector for Websocket connections.
// We need this to be public so [`stream`] can re-use TLS on multiple shards
// if unconfigured.
tls: TlsContainer,
tls: Arc<Connector>,
/// Token used to authenticate when identifying with the gateway.
///
/// The token is prefixed with "Bot ", which is required by Discord for
Expand Down Expand Up @@ -145,7 +146,7 @@ impl Config {
}

/// Immutable reference to the TLS connector in use by the shard.
pub(crate) const fn tls(&self) -> &TlsContainer {
pub(crate) fn tls(&self) -> &Connector {
&self.tls
}

Expand Down Expand Up @@ -193,7 +194,7 @@ impl ConfigBuilder {
queue: Arc::new(LocalQueue::new()),
ratelimit_messages: true,
session: None,
tls: TlsContainer::new().unwrap(),
tls: Arc::new(Connector::new().unwrap()),
token: Token::new(token.into_boxed_str()),
},
}
Expand Down
43 changes: 20 additions & 23 deletions twilight-gateway/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! Utilities for creating Websocket connections.

use crate::{error::ReceiveMessageError, tls::TlsContainer, API_VERSION};
use crate::{
error::{ReceiveMessageError, ReceiveMessageErrorType},
API_VERSION,
};
use std::fmt::{Display, Formatter, Result as FmtResult};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, MaybeTlsStream, WebSocketStream};
use tokio_websockets::{ClientBuilder, Connector, MaybeTlsStream, WebsocketStream};

/// Query argument with zlib-stream enabled.
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
Expand All @@ -16,29 +19,12 @@ const COMPRESSION_FEATURES: &str = "";
/// URL of the Discord gateway.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Configuration used for Websocket connections.
///
/// `max_frame_size` and `max_message_queue` limits are disabled because
/// Discord is not a malicious actor and having a limit has caused problems on
/// large [`GuildCreate`] payloads.
///
/// `accept_unmasked_frames` and `max_send_queue` are set to their
/// defaults.
///
/// [`GuildCreate`]: twilight_model::gateway::payload::incoming::GuildCreate
const WEBSOCKET_CONFIG: WebSocketConfig = WebSocketConfig {
accept_unmasked_frames: false,
max_frame_size: None,
max_message_size: None,
max_send_queue: None,
};

/// [`tokio_tungstenite`] library Websocket connection.
/// [`tokio_websockets`] library Websocket connection.
///
/// Connections are used by [`Shard`]s when reconnecting.
///
/// [`Shard`]: crate::Shard
pub type Connection = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub type Connection = WebsocketStream<MaybeTlsStream<TcpStream>>;

/// Formatter for a gateway URL, with the API version and compression features
/// specified.
Expand Down Expand Up @@ -93,12 +79,23 @@ impl Display for ConnectionUrl<'_> {
#[tracing::instrument(skip_all)]
pub async fn connect(
maybe_gateway_url: Option<&str>,
tls: &TlsContainer,
tls: &Connector,
) -> Result<Connection, ReceiveMessageError> {
let url = ConnectionUrl::new(maybe_gateway_url).to_string();

tracing::debug!(?url, "shaking hands with gateway");
let stream = tls.connect(&url, WEBSOCKET_CONFIG).await?;

let (stream, _) = ClientBuilder::new()
.uri(&url)
.expect("Gateway URL must be valid")
.fail_fast_on_invalid_utf8(false)
.connector(tls)
.connect()
.await
.map_err(|source| ReceiveMessageError {
kind: ReceiveMessageErrorType::Reconnect,
source: Some(Box::new(source)),
})?;

Ok(stream)
}
Expand Down
1 change: 0 additions & 1 deletion twilight-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ mod message;
mod ratelimiter;
mod session;
mod shard;
mod tls;

#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
pub use self::inflater::Inflater;
Expand Down
52 changes: 27 additions & 25 deletions twilight-gateway/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
//! input will not be checked and will be passed directly to the underlying
//! websocket library.

use tokio_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame as TungsteniteCloseFrame},
Message as TungsteniteMessage,
};
use std::borrow::Cow;

use tokio_websockets::{CloseCode, Message as WebsocketMessage};
use twilight_model::gateway::CloseFrame;

/// Message to send over the connection to the remote.
Expand All @@ -25,33 +24,36 @@ pub enum Message {
}

impl Message {
/// Convert a `tungstenite` websocket message into a `twilight` websocket
/// Convert a `tokio-websockets` websocket message into a `twilight` websocket
/// message.
pub(crate) fn from_tungstenite(tungstenite: TungsteniteMessage) -> Option<Self> {
match tungstenite {
TungsteniteMessage::Close(frame) => Some(Self::Close(frame.map(|frame| CloseFrame {
code: frame.code.into(),
reason: frame.reason,
}))),
TungsteniteMessage::Text(string) => Some(Self::Text(string)),
TungsteniteMessage::Binary(_)
| TungsteniteMessage::Frame(_)
| TungsteniteMessage::Ping(_)
| TungsteniteMessage::Pong(_) => None,
pub(crate) fn from_websocket_msg(msg: WebsocketMessage) -> Option<Self> {
if msg.is_close() {
let (code, reason) = msg.as_close().unwrap();

let frame = code.map(|code| CloseFrame {
code: code.into(),
reason: Cow::Owned(reason.unwrap_or_default().to_string()),
});

Some(Self::Close(frame))
} else if msg.is_text() {
Some(Self::Text(msg.as_text().unwrap().to_owned()))
} else {
None
}
}

/// Convert a `twilight` websocket message into a `tungstenite` websocket
/// Convert a `twilight` websocket message into a `tokio-websockets` websocket
/// message.
pub(crate) fn into_tungstenite(self) -> TungsteniteMessage {
pub(crate) fn into_websocket_msg(self) -> WebsocketMessage {
match self {
Self::Close(frame) => {
TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
code: CloseCode::from(frame.code),
reason: frame.reason,
}))
}
Self::Text(string) => TungsteniteMessage::Text(string),
Self::Close(frame) => WebsocketMessage::close(
frame
.as_ref()
.and_then(|f| CloseCode::try_from(f.code).ok()),
frame.map(|f| f.reason).as_deref(),
),
Self::Text(string) => WebsocketMessage::text(string),
}
}
}
Expand Down
38 changes: 21 additions & 17 deletions twilight-gateway/src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ use crate::{
session::Session,
Config, Message, ShardId,
};
use futures_util::{SinkExt, StreamExt};
use futures_util::{FutureExt, SinkExt, StreamExt};
use serde::{de::DeserializeOwned, Deserialize};
#[cfg(any(
feature = "native",
Expand All @@ -89,7 +89,7 @@ use tokio::{
task::JoinHandle,
time::{self, Duration, Instant, Interval, MissedTickBehavior},
};
use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage};
use tokio_websockets::{Error as WebsocketError, Message as WebsocketMessage};
use twilight_model::gateway::{
event::{Event, GatewayEventDeserializer},
payload::{
Expand Down Expand Up @@ -569,7 +569,7 @@ impl Shard {
/// Identify with the gateway.
Identify,
/// Handle this incoming gateway message.
Message(Option<Result<TungsteniteMessage, TungsteniteError>>),
Message(Option<Result<WebsocketMessage, WebsocketError>>),
}

match self.status {
Expand All @@ -592,7 +592,7 @@ impl Shard {
}

let message = loop {
let next_action = |cx: &mut Context<'_>| {
let mut next_action = |cx: &mut Context<'_>| {
if !(self.status.is_disconnected() || self.status.is_fatally_closed()) {
if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
return Poll::Ready(Action::Close(frame.expect("shard owns channel")));
Expand Down Expand Up @@ -627,29 +627,35 @@ impl Shard {
}
}

if let Poll::Ready(message) =
Pin::new(&mut self.connection.as_mut().expect("connected").next()).poll(cx)
if let Poll::Ready(message) = self
.connection
.as_mut()
.expect("connected")
.next()
.poll_unpin(cx)
{
return Poll::Ready(Action::Message(message));
}

Poll::Pending
};

match poll_fn(next_action).await {
let action = poll_fn(&mut next_action).await;

match action {
Action::Message(Some(Ok(message))) => {
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
if let TungsteniteMessage::Binary(bytes) = &message {
if message.is_binary() {
if let Some(decompressed) = self
.inflater
.inflate(bytes)
.inflate(message.as_data())
.map_err(ReceiveMessageError::from_compression)?
{
tracing::trace!(%decompressed);
break Message::Text(decompressed);
};
}
if let Some(message) = Message::from_tungstenite(message) {
if let Some(message) = Message::from_websocket_msg(message) {
break message;
}
}
Expand All @@ -663,7 +669,7 @@ impl Shard {
feature = "rustls-native-roots",
feature = "rustls-webpki-roots"
))]
Action::Message(Some(Err(TungsteniteError::Io(e))))
Action::Message(Some(Err(WebsocketError::Io(e))))
if e.kind() == IoErrorKind::UnexpectedEof
// Assert we're directly connected to Discord's gateway.
&& self.config.proxy_url().is_none()
Expand Down Expand Up @@ -773,13 +779,11 @@ impl Shard {

match &message {
Message::Close(frame) => {
// Tungstenite automatically replies to the close message.
// tokio-websockets automatically replies to the close message.
tracing::debug!(?frame, "received websocket close message");
// Don't run `disconnect` if we initiated the close.
if !self.status.is_disconnected() {
self.disconnect(CloseInitiator::Gateway(
frame.as_ref().map(|frame| frame.code),
));
self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
}
}
Message::Text(event) => {
Expand Down Expand Up @@ -884,7 +888,7 @@ impl Shard {
kind: SendErrorType::Sending,
source: None,
})?
.send(message.into_tungstenite())
.send(message.into_websocket_msg())
.await
.map_err(|source| SendError {
kind: SendErrorType::Sending,
Expand Down Expand Up @@ -1105,7 +1109,7 @@ impl Shard {
let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
// First heartbeat should have some jitter, see
// https://discord.com/developers/docs/topics/gateway#heartbeat-interval
let jitter = heartbeat_interval.mul_f64(rand::random());
let jitter = heartbeat_interval.mul_f64(fastrand::f64());
tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

if self.config().ratelimit_messages() {
Expand Down
Loading

0 comments on commit a596c4d

Please sign in to comment.