Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(gateway, http, lavalink): Switch to tokio-websockets #2239

Merged
merged 4 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions twilight-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ version = "0.15.4"

[dependencies]
bitflags = { default-features = false, version = "2" }
futures-util = { default-features = false, features = ["std"], version = "0.3" }
rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" }
fastrand = { default-features = false, features = ["std"], version = "2" }
futures-util = { default-features = false, features = ["sink", "std"], version = "0.3" }
serde = { default-features = false, features = ["derive"], version = "1" }
serde_json = { default-features = false, features = ["std"], version = "1" }
tokio = { default-features = false, features = ["net", "rt", "sync", "time"], version = "1.19" }
tokio-tungstenite = { default-features = false, features = ["connect"], version = "0.19" }
tokio-websockets = { default-features = false, features = ["client", "fastrand", "sha1_smol", "simd"], version = "0.4" }
tracing = { default-features = false, features = ["std", "attributes"], version = "0.1" }
twilight-gateway-queue = { default-features = false, path = "../twilight-gateway-queue", version = "0.15.4" }
twilight-model = { default-features = false, path = "../twilight-model", version = "0.15.4" }
Expand All @@ -34,13 +34,6 @@ flate2 = { default-features = false, optional = true, version = "1.0.24" }
twilight-http = { default-features = false, optional = true, path = "../twilight-http", version = "0.15.4" }
simd-json = { default-features = false, features = ["serde_impl", "swar-number-parsing"], optional = true, version = ">=0.4, <0.11" }

# TLS libraries
# They are needed to track what is used in tokio-tungstenite
native-tls = { default-features = false, optional = true, version = "0.2.8" }
rustls-native-certs = { default-features = false, optional = true, version = "0.6" }
rustls-tls = { default-features = false, optional = true, package = "rustls", version = "0.21" }
webpki-roots = { default-features = false, optional = true, version = "0.23" }

[dev-dependencies]
anyhow = { default-features = false, features = ["std"], version = "1" }
futures = { default-features = false, version = "0.3" }
Expand All @@ -51,9 +44,9 @@ tracing-subscriber = { default-features = false, features = ["fmt", "tracing-log

[features]
default = ["rustls-native-roots", "twilight-http", "zlib-stock"]
native = ["dep:native-tls", "tokio-tungstenite/native-tls"]
rustls-native-roots = ["dep:rustls-tls", "dep:rustls-native-certs", "tokio-tungstenite/rustls-tls-native-roots"]
rustls-webpki-roots = ["dep:rustls-tls", "dep:webpki-roots", "tokio-tungstenite/rustls-tls-webpki-roots"]
native = ["tokio-websockets/native-tls", "tokio-websockets/openssl"]
rustls-native-roots = ["tokio-websockets/rustls-native-roots"]
rustls-webpki-roots = ["tokio-websockets/rustls-webpki-roots"]
zlib-simd = ["dep:flate2", "flate2?/zlib-ng"]
zlib-stock = ["dep:flate2", "flate2?/zlib"]

Expand Down
8 changes: 4 additions & 4 deletions twilight-gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

use crate::{
queue::{InMemoryQueue, Queue},
tls::TlsContainer,
EventTypeFlags, Session,
};
use std::{
fmt::{Debug, Formatter, Result as FmtResult},
sync::Arc,
};
use tokio_websockets::Connector;
use twilight_model::gateway::{
payload::outgoing::{identify::IdentifyProperties, update_presence::UpdatePresencePayload},
Intents,
Expand Down Expand Up @@ -69,7 +69,7 @@ pub struct Config {
/// TLS connector for Websocket connections.
// We need this to be public so [`stream`] can re-use TLS on multiple shards
// if unconfigured.
tls: TlsContainer,
tls: Arc<Connector>,
/// Token used to authenticate when identifying with the gateway.
///
/// The token is prefixed with "Bot ", which is required by Discord for
Expand Down Expand Up @@ -147,7 +147,7 @@ impl Config {
}

/// Immutable reference to the TLS connector in use by the shard.
pub(crate) const fn tls(&self) -> &TlsContainer {
pub(crate) fn tls(&self) -> &Connector {
&self.tls
}

Expand Down Expand Up @@ -195,7 +195,7 @@ impl ConfigBuilder {
queue: Arc::new(InMemoryQueue::default()),
ratelimit_messages: true,
session: None,
tls: TlsContainer::new().unwrap(),
tls: Arc::new(Connector::new().unwrap()),
token: Token::new(token.into_boxed_str()),
},
}
Expand Down
50 changes: 27 additions & 23 deletions twilight-gateway/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! Utilities for creating Websocket connections.

use crate::{error::ReceiveMessageError, tls::TlsContainer, API_VERSION};
use crate::{
error::{ReceiveMessageError, ReceiveMessageErrorType},
API_VERSION,
};
use std::fmt::{Display, Formatter, Result as FmtResult};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, MaybeTlsStream, WebSocketStream};
use tokio_websockets::{ClientBuilder, Connector, Limits, MaybeTlsStream, WebsocketStream};

/// Query argument with zlib-stream enabled.
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
Expand All @@ -16,29 +19,12 @@ const COMPRESSION_FEATURES: &str = "";
/// URL of the Discord gateway.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Configuration used for Websocket connections.
///
/// `max_frame_size` and `max_message_queue` limits are disabled because
/// Discord is not a malicious actor and having a limit has caused problems on
/// large [`GuildCreate`] payloads.
///
/// `accept_unmasked_frames` and `max_send_queue` are set to their
/// defaults.
///
/// [`GuildCreate`]: twilight_model::gateway::payload::incoming::GuildCreate
const WEBSOCKET_CONFIG: WebSocketConfig = WebSocketConfig {
accept_unmasked_frames: false,
max_frame_size: None,
max_message_size: None,
max_send_queue: None,
};

/// [`tokio_tungstenite`] library Websocket connection.
/// [`tokio_websockets`] library Websocket connection.
///
/// Connections are used by [`Shard`]s when reconnecting.
///
/// [`Shard`]: crate::Shard
pub type Connection = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub type Connection = WebsocketStream<MaybeTlsStream<TcpStream>>;

/// Formatter for a gateway URL, with the API version and compression features
/// specified.
Expand Down Expand Up @@ -93,12 +79,30 @@ impl Display for ConnectionUrl<'_> {
#[tracing::instrument(skip_all)]
pub async fn connect(
maybe_gateway_url: Option<&str>,
tls: &TlsContainer,
tls: &Connector,
) -> Result<Connection, ReceiveMessageError> {
let url = ConnectionUrl::new(maybe_gateway_url).to_string();

// Limits to impose on Websocket connections.
//
// `max_payload_len` limit is disabled because Discord is not a malicious
// actor and having a limit has caused problems on large `GuildCreate`
// payloads.
let limits = Limits::default().max_payload_len(None);

tracing::debug!(?url, "shaking hands with gateway");
let stream = tls.connect(&url, WEBSOCKET_CONFIG).await?;

let (stream, _) = ClientBuilder::new()
.uri(&url)
.expect("Gateway URL must be valid")
.limits(limits)
.connector(tls)
.connect()
.await
.map_err(|source| ReceiveMessageError {
kind: ReceiveMessageErrorType::Reconnect,
source: Some(Box::new(source)),
})?;

Ok(stream)
}
Expand Down
1 change: 0 additions & 1 deletion twilight-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ mod message;
mod ratelimiter;
mod session;
mod shard;
mod tls;

#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
pub use self::inflater::Inflater;
Expand Down
52 changes: 27 additions & 25 deletions twilight-gateway/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
//! input will not be checked and will be passed directly to the underlying
//! websocket library.

use tokio_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame as TungsteniteCloseFrame},
Message as TungsteniteMessage,
};
use std::borrow::Cow;

use tokio_websockets::{CloseCode, Message as WebsocketMessage};
use twilight_model::gateway::CloseFrame;

/// Message to send over the connection to the remote.
Expand All @@ -25,33 +24,36 @@ pub enum Message {
}

impl Message {
/// Convert a `tungstenite` websocket message into a `twilight` websocket
/// Convert a `tokio-websockets` websocket message into a `twilight` websocket
/// message.
pub(crate) fn from_tungstenite(tungstenite: TungsteniteMessage) -> Option<Self> {
match tungstenite {
TungsteniteMessage::Close(frame) => Some(Self::Close(frame.map(|frame| CloseFrame {
code: frame.code.into(),
reason: frame.reason,
}))),
TungsteniteMessage::Text(string) => Some(Self::Text(string)),
TungsteniteMessage::Binary(_)
| TungsteniteMessage::Frame(_)
| TungsteniteMessage::Ping(_)
| TungsteniteMessage::Pong(_) => None,
pub(crate) fn from_websocket_msg(msg: &WebsocketMessage) -> Option<Self> {
if msg.is_close() {
let (code, reason) = msg.as_close().unwrap();

let frame = (code == CloseCode::NO_STATUS_RECEIVED).then(|| CloseFrame {
code: code.into(),
reason: Cow::Owned(reason.to_string()),
});

Some(Self::Close(frame))
} else if msg.is_text() {
Some(Self::Text(msg.as_text().unwrap().to_owned()))
} else {
None
}
}

/// Convert a `twilight` websocket message into a `tungstenite` websocket
/// Convert a `twilight` websocket message into a `tokio-websockets` websocket
/// message.
pub(crate) fn into_tungstenite(self) -> TungsteniteMessage {
pub(crate) fn into_websocket_msg(self) -> WebsocketMessage {
match self {
Self::Close(frame) => {
TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
code: CloseCode::from(frame.code),
reason: frame.reason,
}))
}
Self::Text(string) => TungsteniteMessage::Text(string),
Self::Close(frame) => WebsocketMessage::close(
frame
.as_ref()
.and_then(|f| CloseCode::try_from(f.code).ok()),
frame.map(|f| f.reason).as_deref().unwrap_or_default(),
),
Self::Text(string) => WebsocketMessage::text(string),
}
}
}
Expand Down
38 changes: 17 additions & 21 deletions twilight-gateway/src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ use tokio::{
sync::oneshot,
time::{self, Duration, Instant, Interval, MissedTickBehavior},
};
use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage};
use tokio_websockets::{Error as WebsocketError, Message as WebsocketMessage};
use twilight_model::gateway::{
event::{Event, GatewayEventDeserializer},
payload::{
Expand Down Expand Up @@ -568,18 +568,16 @@ impl Shard {
/// Identify with the gateway.
Identify,
/// Handle this incoming gateway message.
Message(Option<Result<TungsteniteMessage, TungsteniteError>>),
Message(Option<Result<WebsocketMessage, WebsocketError>>),
}

match self.status {
ConnectionStatus::Disconnected {
close_code,
reconnect_attempts,
} => {
// The shard is considered disconnected after having received a
// close frame or encountering a websocket error, but it should
// only reconnect after the underlying TCP connection is closed
// by the server (having returned `Ok(None)`).
// The shard should should only reconnect after the gateway
// closes the underlying TCP connection.
if self.connection.is_none() {
self.reconnect(close_code, reconnect_attempts).await?;
}
Expand Down Expand Up @@ -672,17 +670,17 @@ impl Shard {
match poll_fn(next_action).await {
Action::Message(Some(Ok(message))) => {
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
if let TungsteniteMessage::Binary(bytes) = &message {
if message.is_binary() {
if let Some(decompressed) = self
.inflater
.inflate(bytes)
.inflate(message.as_payload())
.map_err(ReceiveMessageError::from_compression)?
{
tracing::trace!(%decompressed);
break Message::Text(decompressed);
};
}
if let Some(message) = Message::from_tungstenite(message) {
if let Some(message) = Message::from_websocket_msg(&message) {
break message;
}
}
Expand All @@ -696,7 +694,7 @@ impl Shard {
feature = "rustls-native-roots",
feature = "rustls-webpki-roots"
))]
Action::Message(Some(Err(TungsteniteError::Io(e))))
Action::Message(Some(Err(WebsocketError::Io(e))))
if e.kind() == IoErrorKind::UnexpectedEof
// Assert we're directly connected to Discord's gateway.
&& self.config.proxy_url().is_none()
Expand Down Expand Up @@ -726,11 +724,11 @@ impl Shard {
ConnectionStatus::FatallyClosed { close_code } => {
return Err(ReceiveMessageError::from_fatally_closed(close_code))
}
_ => unreachable!(
"stream ended because websocket is closed (received close frame sets \
status to disconnected or fatally closed) or because it errored (which \
also sets status to disconnected)"
),
_ => {
// Abnormal closure without close frame exchange.
self.disconnect(CloseInitiator::None);
self.reconnect(None, 0).await?;
Comment on lines +728 to +730
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we do need to return something here to indicate upstream that the shard's closing. Again, let's leave this as is but remember to check back on this prior to a release. I plan on rewriting large parts of the internals of Shard (cancel safety) during which I'll fix this too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's get this merged into next very soon then so we can start refactoring the shard internals

}
};

continue;
Expand Down Expand Up @@ -804,13 +802,11 @@ impl Shard {

match &message {
Message::Close(frame) => {
// Tungstenite automatically replies to the close message.
// tokio-websockets automatically replies to the close message.
tracing::debug!(?frame, "received websocket close message");
// Don't run `disconnect` if we initiated the close.
if !self.status.is_disconnected() {
self.disconnect(CloseInitiator::Gateway(
frame.as_ref().map(|frame| frame.code),
));
self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
}
}
Message::Text(event) => {
Expand Down Expand Up @@ -913,7 +909,7 @@ impl Shard {
kind: SendErrorType::Sending,
source: None,
})?
.send(message.into_tungstenite())
.send(message.into_websocket_msg())
.await
.map_err(|source| SendError {
kind: SendErrorType::Sending,
Expand Down Expand Up @@ -1116,7 +1112,7 @@ impl Shard {
let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
// First heartbeat should have some jitter, see
// https://discord.com/developers/docs/topics/gateway#heartbeat-interval
let jitter = heartbeat_interval.mul_f64(rand::random());
let jitter = heartbeat_interval.mul_f64(fastrand::f64());
tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

if self.config().ratelimit_messages() {
Expand Down
Loading