Skip to content

Commit

Permalink
Merge pull request dani-garcia#3440 from BlackDex/switch-ws-to-streams
Browse files Browse the repository at this point in the history
Small update to Rocket WebSockets
  • Loading branch information
dani-garcia authored Apr 17, 2023
2 parents 5866338 + 48cc31a commit 3d7e80a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 39 deletions.
46 changes: 23 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ futures = "0.3.28"
tokio = { version = "1.27.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal"] }

# A generic serialization/deserialization framework
serde = { version = "1.0.159", features = ["derive"] }
serde = { version = "1.0.160", features = ["derive"] }
serde_json = "1.0.95"

# A safe, extensible ORM and Query builder
Expand Down Expand Up @@ -133,7 +133,7 @@ data-url = "0.2.0"
bytes = "1.4.0"

# Cache function results (Used for version check and favicon fetching)
cached = "0.42.0"
cached = "0.43.0"

# Used for custom short lived cookie jar during favicon extraction
cookie = "0.16.2"
Expand Down
26 changes: 12 additions & 14 deletions src/api/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async fn websockets_hub<'r>(
ws: rocket_ws::WebSocket,
data: WsAccessToken,
ip: ClientIp,
) -> Result<rocket_ws::Channel<'r>, Error> {
) -> Result<rocket_ws::Stream!['r], Error> {
let addr = ip.ip;
info!("Accepting Rocket WS connection from {addr}");

Expand All @@ -93,32 +93,32 @@ async fn websockets_hub<'r>(
(rx, WSEntryMapGuard::new(users, claims.sub, entry_uuid, addr))
};

Ok(ws.channel(move |mut stream| {
Box::pin(async move {
// Make sure the guard is moved into the channel future so it's not dropped earlier
Ok({
rocket_ws::Stream! { ws => {
let mut ws = ws;
let _guard = guard;
let mut interval = tokio::time::interval(Duration::from_secs(15));
loop {
tokio::select! {
res = stream.next() => {
res = ws.next() => {
match res {
Some(Ok(message)) => {
match message {
// Respond to any pings
Message::Ping(ping) => stream.send(Message::Pong(ping)).await?,
Message::Ping(ping) => yield Message::Pong(ping),
Message::Pong(_) => {/* Ignored */},

// We should receive an initial message with the protocol and version, and we will reply to it
Message::Text(ref message) => {
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);

if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
yield Message::binary(INITIAL_RESPONSE);
continue;
}
}
// Just echo anything else the client sends
_ => stream.send(message).await?,
_ => yield message,
}
}
_ => break,
Expand All @@ -127,18 +127,16 @@ async fn websockets_hub<'r>(

res = rx.recv() => {
match res {
Some(res) => stream.send(res).await?,
Some(res) => yield res,
None => break,
}
}

_ = interval.tick() => stream.send(Message::Ping(create_ping())).await?
_ = interval.tick() => yield Message::Ping(create_ping())
}
}

Ok(())
})
}))
}}
})
}

//
Expand Down

0 comments on commit 3d7e80a

Please sign in to comment.