Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ezsockets with tokio-websockets #14

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
408 changes: 80 additions & 328 deletions Cargo.lock

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ keywords = [
]

[dependencies]
tokio = { version = "=1.38", features = ["sync"] }
tracing = "0.1"
ezsockets = { version = "0.6", features = ["native-tls", "native_client"] }
tokio = { version = "1.40", features = ["sync", "macros", "rt"] }
tracing = { version = "0.1", default-features = false }
tokio-websockets = { version = "0.10", features = ["client"] }
futures-util = { version = "0.3", default-features = false, features = [ "std", "sink" ] }
async-trait = "0.1"
tokio-stream = { version = "0.1", features = ["sync"] }
pin-project-lite = "0.2.14"
Expand All @@ -41,8 +42,14 @@ serde_json = "1.0.114"
os_info = "3"

ssml = "0.1.0"
async-channel = "1.9.0" # needed for ezsockets 0.6 for call_with;

[features]
default = ["tws-rustls-native-roots", "tws-fastrand", "tws-smol-sha1"]
tws-rustls-native-roots = ["tokio-websockets/rustls-webpki-roots", "tokio-websockets/ring"]
tws-rustls-webpki-roots = ["tokio-websockets/rustls-native-roots", "tokio-websockets/ring"]
tws-native-tls = ["tokio-websockets/native-tls"]
tws-smol-sha1 = ["tokio-websockets/sha1_smol"]
tws-fastrand = ["tokio-websockets/fastrand"]

[dev-dependencies]
tokio = { version = "1.36.0", features = ["full"] }
Expand Down
247 changes: 118 additions & 129 deletions src/connector/client.rs
Original file line number Diff line number Diff line change
@@ -1,163 +1,152 @@
use crate::connector::message::Message;
use async_trait::async_trait;
use ezsockets::client::ClientCloseMode;
use ezsockets::{ClientConfig, CloseCode, CloseFrame, Error};
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use futures_util::SinkExt;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::StreamExt;
use tokio_websockets::ClientBuilder;

enum InternalMessage {
SendMessage(tokio_websockets::Message),
Subscribe(oneshot::Sender<crate::Result<broadcast::Receiver<crate::Result<Message>>>>),
Disconnect,
}

#[derive(Clone)]
pub struct Client {
handle: ezsockets::Client<BaseClient>,
channel: mpsc::Sender<InternalMessage>,
}

impl Client {
/// Create a new client.
pub(crate) fn new(handle: ezsockets::Client<BaseClient>) -> Self {
Self { handle }
fn new(channel: mpsc::Sender<InternalMessage>) -> Self {
Self { channel }
}
}
impl Client {
/// Send a text message to the server.
pub fn send_text(&self, text: impl Into<String>) -> crate::Result<()> {
self.handle.text(text)?;
pub async fn send_text(&self, text: impl Into<String>) -> crate::Result<()> {
self.channel
.send(InternalMessage::SendMessage(
tokio_websockets::Message::text(text.into()),
))
.await?;
Ok(())
}

/// Send a binary message to the server.
pub fn send_binary(&self, bytes: impl Into<Vec<u8>>) -> crate::Result<()> {
self.handle.binary(bytes)?;
pub async fn send_binary(&self, bytes: impl Into<Vec<u8>>) -> crate::Result<()> {
self.channel
.send(InternalMessage::SendMessage(
tokio_websockets::Message::binary(bytes.into()),
))
.await?;
Ok(())
}

/// Stream messages from the server.
pub async fn stream(&self) -> crate::Result<BroadcastStream<crate::Result<Message>>> {
self.handle.call_with(Call::Subscribe).await.map_or(
Err(crate::Error::InternalError(
"Failed to subscribe to messages".to_string(),
)),
|rx| Ok(BroadcastStream::new(rx)),
)
let (sender, receiver) = oneshot::channel();
self.channel
.send(InternalMessage::Subscribe(sender))
.await?;
Ok(BroadcastStream::new(receiver.await.map_err(|_| {
crate::Error::InternalError("Failed to subscribe to messages".to_string())
})??))
}
}

impl Client {
pub(crate) async fn connect(config: ClientConfig) -> crate::Result<Self> {
let (await_connection_tx, await_connection_rx) = oneshot::channel::<()>();
let (client, future) =
ezsockets::connect(|_| BaseClient::new(await_connection_tx), config).await;

tokio::select! {
_ = await_connection_rx => {
tracing::debug!("Client is ready to send messages");
Ok(Client::new(client))
}
_ = future => {
tracing::error!("Connection closed before the client was ready to send messages");
Err(crate::Error::ServerDisconnect("Connection closed before the client was ready to send messages".to_string()))
pub(crate) async fn connect(config: ClientBuilder<'static>) -> crate::Result<Self> {
let (mut stream, _res) = config.connect().await.unwrap();
let (sender, mut receiver) = mpsc::channel(16);
tokio::spawn(async move {
let (broadcaster, _) = broadcast::channel(32);
let mut connected = true;
loop {
tokio::select! {
msg = receiver.recv() => {
let Some(msg) = msg else {
// Receiving `None` here means the client has been dropped, so the task should stop as well.
break;
};
match msg {
InternalMessage::SendMessage(msg) => {
let _ = stream.send(msg).await;
},
InternalMessage::Subscribe(c) => {
if !connected {
// We got disconnected from the server for whatever reason. Since we are currently
// expecting a stream, now would be a good time to try to reconnect.
let mut last_error = None;
for i in 0..3 {
tracing::debug!("Reconncting ({i}/3)");
match config.connect().await {
Ok((new_stream, _)) => {
tracing::debug!("Reconnected successfully");
drop(last_error.take());
connected = true;
stream = new_stream;
break;
}
Err(e) => {
tracing::warn!("Failed to reconnect ({i}/3): {e}");
last_error.replace(e);
}
}
}

// If we still haven't reconnected, send the error to the client.
if let Some(err) = last_error.take() {
c.send(Err(crate::Error::ConnectionError(err.to_string()))).unwrap();
continue;
}
}

c.send(Ok(broadcaster.subscribe())).unwrap();
},
InternalMessage::Disconnect => {
connected = false;
let _ = stream.close().await;
}
}
}
msg = stream.next(), if connected => {
let Some(msg) = msg else {
// Receiving `None` here means the socket has been disconnected and can no longer receive messages.
// We set `connected` to false just to make sure that the stream isn't polled again until we're reconnected.
connected = false;
continue;
};
match msg {
Ok(msg) => {
if msg.is_text() {
let text = msg.as_text().unwrap();
broadcaster.send(Message::try_from(text)).unwrap();
} else if msg.is_binary() {
let bin = msg.as_payload();
broadcaster.send(Message::try_from(&**bin)).unwrap();
} else if msg.is_close() {
connected = false;

let close = msg.as_close().unwrap();
tracing::info!(reason = ?close.0, msg = close.1, "disconnected from server");
}
},
Err(e) => {
tracing::warn!(?e, "connection errored");
connected = false;
}
}
}
}
}
}
});
Ok(Client::new(sender))
}

/// Disconnect the client.
pub(crate) async fn disconnect(self) -> crate::Result<()> {
let _ = self.handle.close(None)?;
self.channel.send(InternalMessage::Disconnect).await?;
Ok(())
}
}

pub(crate) enum Call {
Subscribe(async_channel::Sender<broadcast::Receiver<crate::Result<Message>>>),
}

pub(crate) struct BaseClient {
messages: broadcast::Sender<crate::Result<Message>>,
ready: Option<oneshot::Sender<()>>,
}

impl BaseClient {
pub(crate) fn new(ready: oneshot::Sender<()>) -> Self {
let (sender, _) = broadcast::channel(1024 * 5);
Self {
messages: sender,
ready: Some(ready),
}
}
}

#[async_trait]
impl ezsockets::ClientExt for BaseClient {
type Call = Call;

async fn on_text(&mut self, text: String) -> Result<(), Error> {
tracing::debug!("Received text: {:?}", text);

return match text.try_into() {
Ok(value) => {
self.messages.send(Ok(value))?;
Ok(())
}
_ => Err(Error::from("Error parsing text".to_string())),
};
}

async fn on_binary(&mut self, bytes: Vec<u8>) -> Result<(), Error> {
tracing::debug!("Received binary: {:?}", bytes.len());
tracing::trace!("Received Binary data: {:?}", bytes);
return match bytes.try_into() {
Ok(value) => {
self.messages.send(Ok(value))?;
Ok(())
}
_ => Err(Error::from("Error parsing bytes".to_string())),
};
}

async fn on_call(&mut self, call: Self::Call) -> Result<(), Error> {
match call {
Call::Subscribe(respond_to) => {
let _ = respond_to.send(self.messages.subscribe()).await?;
}
};
Ok(())
}

async fn on_connect(&mut self) -> Result<(), Error> {
// send the ready signal.
// This is used to notify the connector that the client is ready to send messages.
if let Some(ready) = self.ready.take() {
let _ = ready.send(());
}

Ok(())
}

async fn on_close(&mut self, frame: Option<CloseFrame>) -> Result<ClientCloseMode, Error> {
tracing::debug!("Server close the connection...{:?}", frame);

match frame {
Some(CloseFrame { code, reason }) => {
let mode = match code {
CloseCode::Restart | CloseCode::Again | CloseCode::Normal => {
tracing::debug!("Reconnecting...");
ClientCloseMode::Reconnect
}
_ => {
tracing::debug!("Sending server error message...");
self.messages.send(Err(crate::Error::ServerDisconnect(
reason.clone().to_string(),
)))?;
tracing::debug!("Closing...");
ClientCloseMode::Close
}
};
tracing::debug!("Close mode: {:?}", mode);
Ok(mode)
}
None => {
tracing::debug!("Reconnecting...");
Ok(ClientCloseMode::Reconnect)
}
}
}
}
Loading