Skip to content

Commit

Permalink
refactor(gateway, lavalink): Switch to tokio-websockets
Browse files Browse the repository at this point in the history
Signed-off-by: Jens Reidel <[email protected]>
  • Loading branch information
Gelbpunkt committed Sep 10, 2023
1 parent 89dee72 commit 14cf608
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 466 deletions.
19 changes: 6 additions & 13 deletions twilight-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ version = "0.15.4"

[dependencies]
bitflags = { default-features = false, version = "2" }
futures-util = { default-features = false, features = ["std"], version = "0.3" }
rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" }
fastrand = { default-features = false, features = ["std"], version = "2" }
futures-util = { default-features = false, features = ["sink", "std"], version = "0.3" }
serde = { default-features = false, features = ["derive"], version = "1" }
serde_json = { default-features = false, features = ["std"], version = "1" }
tokio = { default-features = false, features = ["net", "rt", "sync", "time"], version = "1.19" }
tokio-tungstenite = { default-features = false, features = ["connect"], version = "0.19" }
tokio-websockets = { default-features = false, features = ["client", "fastrand", "sha1_smol", "simd"], version = "0.4" }
tracing = { default-features = false, features = ["std", "attributes"], version = "0.1" }
twilight-gateway-queue = { default-features = false, path = "../twilight-gateway-queue", version = "0.15.4" }
twilight-model = { default-features = false, path = "../twilight-model", version = "0.15.4" }
Expand All @@ -34,13 +34,6 @@ flate2 = { default-features = false, optional = true, version = "1.0.24" }
twilight-http = { default-features = false, optional = true, path = "../twilight-http", version = "0.15.4" }
simd-json = { default-features = false, features = ["serde_impl", "swar-number-parsing"], optional = true, version = ">=0.4, <0.11" }

# TLS libraries
# They are needed to track what is used in tokio-tungstenite
native-tls = { default-features = false, optional = true, version = "0.2.8" }
rustls-native-certs = { default-features = false, optional = true, version = "0.6" }
rustls-tls = { default-features = false, optional = true, package = "rustls", version = "0.21" }
webpki-roots = { default-features = false, optional = true, version = "0.23" }

[dev-dependencies]
anyhow = { default-features = false, features = ["std"], version = "1" }
futures = { default-features = false, version = "0.3" }
Expand All @@ -51,9 +44,9 @@ tracing-subscriber = { default-features = false, features = ["fmt", "tracing-log

[features]
default = ["rustls-native-roots", "twilight-http", "zlib-stock"]
native = ["dep:native-tls", "tokio-tungstenite/native-tls"]
rustls-native-roots = ["dep:rustls-tls", "dep:rustls-native-certs", "tokio-tungstenite/rustls-tls-native-roots"]
rustls-webpki-roots = ["dep:rustls-tls", "dep:webpki-roots", "tokio-tungstenite/rustls-tls-webpki-roots"]
native = ["tokio-websockets/native-tls", "tokio-websockets/openssl"]
rustls-native-roots = ["tokio-websockets/rustls-native-roots"]
rustls-webpki-roots = ["tokio-websockets/rustls-webpki-roots"]
zlib-simd = ["dep:flate2", "flate2?/zlib-ng"]
zlib-stock = ["dep:flate2", "flate2?/zlib"]

Expand Down
9 changes: 5 additions & 4 deletions twilight-gateway/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! User configuration for shards.

use crate::{tls::TlsContainer, EventTypeFlags, Session};
use crate::{EventTypeFlags, Session};
use std::{
fmt::{Debug, Formatter, Result as FmtResult},
sync::Arc,
};
use tokio_websockets::Connector;
use twilight_gateway_queue::{LocalQueue, Queue};
use twilight_model::gateway::{
payload::outgoing::{identify::IdentifyProperties, update_presence::UpdatePresencePayload},
Expand Down Expand Up @@ -66,7 +67,7 @@ pub struct Config {
/// TLS connector for Websocket connections.
// We need this to be public so [`stream`] can re-use TLS on multiple shards
// if unconfigured.
tls: TlsContainer,
tls: Arc<Connector>,
/// Token used to authenticate when identifying with the gateway.
///
/// The token is prefixed with "Bot ", which is required by Discord for
Expand Down Expand Up @@ -144,7 +145,7 @@ impl Config {
}

/// Immutable reference to the TLS connector in use by the shard.
pub(crate) const fn tls(&self) -> &TlsContainer {
pub(crate) fn tls(&self) -> &Connector {
&self.tls
}

Expand Down Expand Up @@ -192,7 +193,7 @@ impl ConfigBuilder {
queue: Arc::new(LocalQueue::new()),
ratelimit_messages: true,
session: None,
tls: TlsContainer::new().unwrap(),
tls: Arc::new(Connector::new().unwrap()),
token: Token::new(token.into_boxed_str()),
},
}
Expand Down
50 changes: 27 additions & 23 deletions twilight-gateway/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! Utilities for creating Websocket connections.

use crate::{error::ReceiveMessageError, tls::TlsContainer, API_VERSION};
use crate::{
error::{ReceiveMessageError, ReceiveMessageErrorType},
API_VERSION,
};
use std::fmt::{Display, Formatter, Result as FmtResult};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, MaybeTlsStream, WebSocketStream};
use tokio_websockets::{ClientBuilder, Connector, Limits, MaybeTlsStream, WebsocketStream};

/// Query argument with zlib-stream enabled.
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
Expand All @@ -16,29 +19,12 @@ const COMPRESSION_FEATURES: &str = "";
/// URL of the Discord gateway.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Configuration used for Websocket connections.
///
/// `max_frame_size` and `max_message_queue` limits are disabled because
/// Discord is not a malicious actor and having a limit has caused problems on
/// large [`GuildCreate`] payloads.
///
/// `accept_unmasked_frames` and `max_send_queue` are set to their
/// defaults.
///
/// [`GuildCreate`]: twilight_model::gateway::payload::incoming::GuildCreate
const WEBSOCKET_CONFIG: WebSocketConfig = WebSocketConfig {
accept_unmasked_frames: false,
max_frame_size: None,
max_message_size: None,
max_send_queue: None,
};

/// [`tokio_tungstenite`] library Websocket connection.
/// [`tokio_websockets`] library Websocket connection.
///
/// Connections are used by [`Shard`]s when reconnecting.
///
/// [`Shard`]: crate::Shard
pub type Connection = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub type Connection = WebsocketStream<MaybeTlsStream<TcpStream>>;

/// Formatter for a gateway URL, with the API version and compression features
/// specified.
Expand Down Expand Up @@ -93,12 +79,30 @@ impl Display for ConnectionUrl<'_> {
#[tracing::instrument(skip_all)]
pub async fn connect(
maybe_gateway_url: Option<&str>,
tls: &TlsContainer,
tls: &Connector,
) -> Result<Connection, ReceiveMessageError> {
let url = ConnectionUrl::new(maybe_gateway_url).to_string();

// Limits to impose on Websocket connections.
//
// `max_payload_len` limit is disabled because Discord is not a malicious
// actor and having a limit has caused problems on large `GuildCreate`
// payloads.
let limits = Limits::default().max_payload_len(None);

tracing::debug!(?url, "shaking hands with gateway");
let stream = tls.connect(&url, WEBSOCKET_CONFIG).await?;

let (stream, _) = ClientBuilder::new()
.uri(&url)
.expect("Gateway URL must be valid")
.limits(limits)
.connector(tls)
.connect()
.await
.map_err(|source| ReceiveMessageError {
kind: ReceiveMessageErrorType::Reconnect,
source: Some(Box::new(source)),
})?;

Ok(stream)
}
Expand Down
1 change: 0 additions & 1 deletion twilight-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ mod message;
mod ratelimiter;
mod session;
mod shard;
mod tls;

#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
pub use self::inflater::Inflater;
Expand Down
52 changes: 27 additions & 25 deletions twilight-gateway/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
//! input will not be checked and will be passed directly to the underlying
//! websocket library.

use tokio_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame as TungsteniteCloseFrame},
Message as TungsteniteMessage,
};
use std::borrow::Cow;

use tokio_websockets::{CloseCode, Message as WebsocketMessage};
use twilight_model::gateway::CloseFrame;

/// Message to send over the connection to the remote.
Expand All @@ -25,33 +24,36 @@ pub enum Message {
}

impl Message {
/// Convert a `tungstenite` websocket message into a `twilight` websocket
/// Convert a `tokio-websockets` websocket message into a `twilight` websocket
/// message.
pub(crate) fn from_tungstenite(tungstenite: TungsteniteMessage) -> Option<Self> {
match tungstenite {
TungsteniteMessage::Close(frame) => Some(Self::Close(frame.map(|frame| CloseFrame {
code: frame.code.into(),
reason: frame.reason,
}))),
TungsteniteMessage::Text(string) => Some(Self::Text(string)),
TungsteniteMessage::Binary(_)
| TungsteniteMessage::Frame(_)
| TungsteniteMessage::Ping(_)
| TungsteniteMessage::Pong(_) => None,
pub(crate) fn from_websocket_msg(msg: &WebsocketMessage) -> Option<Self> {
if msg.is_close() {
let (code, reason) = msg.as_close().unwrap();

let frame = (code == CloseCode::NO_STATUS_RECEIVED).then(|| CloseFrame {
code: code.into(),
reason: Cow::Owned(reason.to_string()),
});

Some(Self::Close(frame))
} else if msg.is_text() {
Some(Self::Text(msg.as_text().unwrap().to_owned()))
} else {
None
}
}

/// Convert a `twilight` websocket message into a `tungstenite` websocket
/// Convert a `twilight` websocket message into a `tokio-websockets` websocket
/// message.
pub(crate) fn into_tungstenite(self) -> TungsteniteMessage {
pub(crate) fn into_websocket_msg(self) -> WebsocketMessage {
match self {
Self::Close(frame) => {
TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
code: CloseCode::from(frame.code),
reason: frame.reason,
}))
}
Self::Text(string) => TungsteniteMessage::Text(string),
Self::Close(frame) => WebsocketMessage::close(
frame
.as_ref()
.and_then(|f| CloseCode::try_from(f.code).ok()),
frame.map(|f| f.reason).as_deref().unwrap_or_default(),
),
Self::Text(string) => WebsocketMessage::text(string),
}
}
}
Expand Down
22 changes: 10 additions & 12 deletions twilight-gateway/src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ use tokio::{
task::JoinHandle,
time::{self, Duration, Instant, Interval, MissedTickBehavior},
};
use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage};
use tokio_websockets::{Error as WebsocketError, Message as WebsocketMessage};
use twilight_model::gateway::{
event::{Event, GatewayEventDeserializer},
payload::{
Expand Down Expand Up @@ -568,7 +568,7 @@ impl Shard {
/// Identify with the gateway.
Identify,
/// Handle this incoming gateway message.
Message(Option<Result<TungsteniteMessage, TungsteniteError>>),
Message(Option<Result<WebsocketMessage, WebsocketError>>),
}

match self.status {
Expand Down Expand Up @@ -667,17 +667,17 @@ impl Shard {
match poll_fn(next_action).await {
Action::Message(Some(Ok(message))) => {
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
if let TungsteniteMessage::Binary(bytes) = &message {
if message.is_binary() {
if let Some(decompressed) = self
.inflater
.inflate(bytes)
.inflate(message.as_payload())
.map_err(ReceiveMessageError::from_compression)?
{
tracing::trace!(%decompressed);
break Message::Text(decompressed);
};
}
if let Some(message) = Message::from_tungstenite(message) {
if let Some(message) = Message::from_websocket_msg(&message) {
break message;
}
}
Expand All @@ -691,7 +691,7 @@ impl Shard {
feature = "rustls-native-roots",
feature = "rustls-webpki-roots"
))]
Action::Message(Some(Err(TungsteniteError::Io(e))))
Action::Message(Some(Err(WebsocketError::Io(e))))
if e.kind() == IoErrorKind::UnexpectedEof
// Assert we're directly connected to Discord's gateway.
&& self.config.proxy_url().is_none()
Expand Down Expand Up @@ -801,13 +801,11 @@ impl Shard {

match &message {
Message::Close(frame) => {
// Tungstenite automatically replies to the close message.
// tokio-websockets automatically replies to the close message.
tracing::debug!(?frame, "received websocket close message");
// Don't run `disconnect` if we initiated the close.
if !self.status.is_disconnected() {
self.disconnect(CloseInitiator::Gateway(
frame.as_ref().map(|frame| frame.code),
));
self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
}
}
Message::Text(event) => {
Expand Down Expand Up @@ -910,7 +908,7 @@ impl Shard {
kind: SendErrorType::Sending,
source: None,
})?
.send(message.into_tungstenite())
.send(message.into_websocket_msg())
.await
.map_err(|source| SendError {
kind: SendErrorType::Sending,
Expand Down Expand Up @@ -1112,7 +1110,7 @@ impl Shard {
let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
// First heartbeat should have some jitter, see
// https://discord.com/developers/docs/topics/gateway#heartbeat-interval
let jitter = heartbeat_interval.mul_f64(rand::random());
let jitter = heartbeat_interval.mul_f64(fastrand::f64());
tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

if self.config().ratelimit_messages() {
Expand Down
Loading

0 comments on commit 14cf608

Please sign in to comment.