Skip to content

Commit 2b207bc

Browse files
committed
Update axum
1 parent f11b814 commit 2b207bc

File tree

11 files changed

+63
-435
lines changed

11 files changed

+63
-435
lines changed

Cargo.lock

Lines changed: 14 additions & 37 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ anymap = "0.12"
132132
arrayvec = "0.7.2"
133133
async-stream = "0.3.6"
134134
async-trait = "0.1.68"
135-
axum = { version = "0.7", features = ["tracing"] }
136-
axum-extra = { version = "0.9", features = ["typed-header"] }
135+
axum = { version = "0.8.4", features = ["tracing", "ws"] }
136+
axum-extra = { version = "0.10", features = ["typed-header"] }
137137
backtrace = "0.3.66"
138138
base64 = "0.21.2"
139139
bigdecimal = "0.4.7"

crates/client-api/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ futures = "0.3"
4141
bytes = "1"
4242
tracing.workspace = true
4343
bytestring = "1"
44-
tokio-tungstenite.workspace = true
4544
itoa.workspace = true
4645
derive_more = "0.99.17"
4746
uuid.workspace = true

crates/client-api/src/auth.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ pub struct SpacetimeAuthHeader {
264264
auth: Option<SpacetimeAuth>,
265265
}
266266

267-
#[async_trait::async_trait]
268267
impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for SpacetimeAuthHeader {
269268
type Rejection = AuthorizationRejection;
270269
async fn from_request_parts(parts: &mut request::Parts, state: &S) -> Result<Self, Self::Rejection> {
@@ -341,7 +340,6 @@ impl SpacetimeAuthHeader {
341340

342341
pub struct SpacetimeAuthRequired(pub SpacetimeAuth);
343342

344-
#[async_trait::async_trait]
345343
impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for SpacetimeAuthRequired {
346344
type Rejection = AuthorizationRejection;
347345
async fn from_request_parts(parts: &mut request::Parts, state: &S) -> Result<Self, Self::Rejection> {

crates/client-api/src/routes/database.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,14 +778,14 @@ where
778778
.route("/names", self.names_put)
779779
.route("/identity", self.identity_get)
780780
.route("/subscribe", self.subscribe_get)
781-
.route("/call/:reducer", self.call_reducer_post)
781+
.route("/call/{reducer}", self.call_reducer_post)
782782
.route("/schema", self.schema_get)
783783
.route("/logs", self.logs_get)
784784
.route("/sql", self.sql_post);
785785

786786
axum::Router::new()
787787
.route("/", self.root_post)
788-
.nest("/:name_or_identity", db_router)
788+
.nest("/{name_or_identity}", db_router)
789789
.route_layer(axum::middleware::from_fn_with_state(ctx, anon_auth_middleware::<S>))
790790
}
791791
}

crates/client-api/src/routes/energy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ where
125125
{
126126
use axum::routing::get;
127127
axum::Router::new().route(
128-
"/:identity",
128+
"/{identity}",
129129
get(get_energy_balance::<S>)
130130
.put(set_energy_balance::<S>)
131131
.post(add_energy::<S>),

crates/client-api/src/routes/identity.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,6 @@ where
144144
.route("/", post(create_identity::<S>))
145145
.route("/public-key", get(get_public_key::<S>))
146146
.route("/websocket-token", post(create_websocket_token::<S>))
147-
.route("/:identity/verify", get(validate_token))
148-
.route("/:identity/databases", get(get_databases::<S>))
147+
.route("/{identity}/verify", get(validate_token))
148+
.route("/{identity}/databases", get(get_databases::<S>))
149149
}

crates/client-api/src/routes/subscribe.rs

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::mem;
33
use std::pin::{pin, Pin};
44
use std::time::Duration;
55

6+
use axum::extract::ws;
67
use axum::extract::{Path, Query, State};
78
use axum::response::IntoResponse;
89
use axum::Extension;
@@ -24,12 +25,8 @@ use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
2425
use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl};
2526
use std::time::Instant;
2627
use tokio::sync::mpsc;
27-
use tokio_tungstenite::tungstenite::Utf8Bytes;
2828

2929
use crate::auth::SpacetimeAuth;
30-
use crate::util::websocket::{
31-
CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade,
32-
};
3330
use crate::util::{NameOrIdentity, XForwardedFor};
3431
use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
3532

@@ -68,7 +65,7 @@ pub async fn handle_websocket<S>(
6865
}): Query<SubscribeQueryParams>,
6966
forwarded_for: Option<TypedHeader<XForwardedFor>>,
7067
Extension(auth): Extension<SpacetimeAuth>,
71-
ws: WebSocketUpgrade,
68+
ws: ws::WebSocketUpgrade,
7269
) -> axum::response::Result<impl IntoResponse>
7370
where
7471
S: NodeDelegate + ControlStateDelegate,
@@ -91,8 +88,17 @@ where
9188

9289
let db_identity = name_or_identity.resolve(&ctx).await?;
9390

94-
let (res, ws_upgrade, protocol) =
95-
ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]);
91+
let ws = ws.protocols([ws_api::BIN_PROTOCOL, ws_api::TEXT_PROTOCOL]);
92+
93+
let protocol = ws.selected_protocol().and_then(|proto| {
94+
if proto == BIN_PROTOCOL {
95+
Some(Protocol::Binary)
96+
} else if proto == TEXT_PROTOCOL {
97+
Some(Protocol::Text)
98+
} else {
99+
None
100+
}
101+
});
96102

97103
let protocol = protocol.ok_or((StatusCode::BAD_REQUEST, "no valid protocol selected"))?;
98104
let client_config = ClientConfig {
@@ -125,20 +131,13 @@ where
125131
name: ctx.client_actor_index().next_client_name(),
126132
};
127133

128-
let ws_config = WebSocketConfig::default()
129-
.max_message_size(Some(0x2000000))
130-
.max_frame_size(None)
131-
.accept_unmasked_frames(false);
132-
133-
tokio::spawn(async move {
134-
let ws = match ws_upgrade.upgrade(ws_config).await {
135-
Ok(ws) => ws,
136-
Err(err) => {
137-
log::error!("WebSocket init error: {}", err);
138-
return;
139-
}
140-
};
134+
let ws = ws
135+
.max_message_size(0x2000000)
136+
.max_frame_size(usize::MAX)
137+
.accept_unmasked_frames(false)
138+
.on_failed_upgrade(|err| log::error!("WebSocket init error: {}", err));
141139

140+
let res = ws.on_upgrade(move |ws| async move {
142141
match forwarded_for {
143142
Some(TypedHeader(XForwardedFor(ip))) => {
144143
log::debug!("New client connected from ip {}", ip)
@@ -180,7 +179,7 @@ where
180179

181180
const LIVELINESS_TIMEOUT: Duration = Duration::from_secs(60);
182181

183-
async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: mpsc::Receiver<SerializableMessage>) {
182+
async fn ws_client_actor(client: ClientConnection, ws: ws::WebSocket, sendrx: mpsc::Receiver<SerializableMessage>) {
184183
// ensure that even if this task gets cancelled, we always cleanup the connection
185184
let mut client = scopeguard::guard(client, |client| {
186185
tokio::spawn(client.disconnect());
@@ -201,7 +200,7 @@ async fn make_progress<Fut: Future>(fut: &mut Pin<&mut MaybeDone<Fut>>) {
201200

202201
async fn ws_client_actor_inner(
203202
client: &mut ClientConnection,
204-
mut ws: WebSocketStream,
203+
mut ws: ws::WebSocket,
205204
mut sendrx: mpsc::Receiver<SerializableMessage>,
206205
) {
207206
let mut liveness_check_interval = tokio::time::interval(LIVELINESS_TIMEOUT);
@@ -280,7 +279,7 @@ async fn ws_client_actor_inner(
280279
let workload = msg.workload();
281280
let num_rows = msg.num_rows();
282281

283-
let msg = datamsg_to_wsmsg(serialize(msg, client.config));
282+
let msg = serialize(msg, client.config);
284283

285284
// These metrics should be updated together,
286285
// or not at all.
@@ -295,7 +294,7 @@ async fn ws_client_actor_inner(
295294
.observe(msg.len() as f64);
296295
}
297296
// feed() buffers the message, but does not necessarily send it
298-
ws.feed(msg).await?;
297+
ws.feed(datamsg_to_wsmsg(msg)).await?;
299298
}
300299
// now we flush all the messages to the socket
301300
ws.flush().await
@@ -323,7 +322,7 @@ async fn ws_client_actor_inner(
323322
// Send a close frame while continuing to poll the `handle_queue`,
324323
// to avoid deadlocks or delays due to enqueued futures holding resources.
325324
let close = also_poll(
326-
ws.close(Some(CloseFrame { code: CloseCode::Away, reason: "module exited".into() })),
325+
ws.send(ws::Message::Close(Some(ws::CloseFrame { code: ws::close_code::AWAY, reason: "module exited".into() }))),
327326
make_progress(&mut current_message),
328327
);
329328
if let Err(e) = close.await {
@@ -341,7 +340,7 @@ async fn ws_client_actor_inner(
341340
if mem::take(&mut got_pong) {
342341
// Send a ping message while continuing to poll the `handle_queue`,
343342
// to avoid deadlocks or delays due to enqueued futures holding resources.
344-
if let Err(e) = also_poll(ws.send(WsMessage::Ping(Bytes::new())), make_progress(&mut current_message)).await {
343+
if let Err(e) = also_poll(ws.send(ws::Message::Ping(Bytes::new())), make_progress(&mut current_message)).await {
345344
log::warn!("error sending ping: {e:#}");
346345
}
347346
continue;
@@ -376,10 +375,10 @@ async fn ws_client_actor_inner(
376375
}
377376
log::debug!("Client caused error on text message: {}", e);
378377
if let Err(e) = ws
379-
.close(Some(CloseFrame {
380-
code: CloseCode::Error,
378+
.send(ws::Message::Close(Some(ws::CloseFrame {
379+
code: ws::close_code::ERROR,
381380
reason: format!("{e:#}").into(),
382-
}))
381+
})))
383382
.await
384383
{
385384
log::warn!("error closing websocket: {e:#}")
@@ -419,34 +418,32 @@ enum ClientMessage {
419418
Message(DataMessage),
420419
Ping(Bytes),
421420
Pong(Bytes),
422-
Close(Option<CloseFrame>),
421+
Close(Option<ws::CloseFrame>),
423422
}
424423
impl ClientMessage {
425-
fn from_message(msg: WsMessage) -> Self {
424+
fn from_message(msg: ws::Message) -> Self {
426425
match msg {
427-
WsMessage::Text(s) => Self::Message(DataMessage::Text(utf8bytes_to_bytestring(s))),
428-
WsMessage::Binary(b) => Self::Message(DataMessage::Binary(b)),
429-
WsMessage::Ping(b) => Self::Ping(b),
430-
WsMessage::Pong(b) => Self::Pong(b),
431-
WsMessage::Close(frame) => Self::Close(frame),
432-
// WebSocket::read_message() never returns a raw Message::Frame
433-
WsMessage::Frame(_) => unreachable!(),
426+
ws::Message::Text(s) => Self::Message(DataMessage::Text(utf8bytes_to_bytestring(s))),
427+
ws::Message::Binary(b) => Self::Message(DataMessage::Binary(b)),
428+
ws::Message::Ping(b) => Self::Ping(b),
429+
ws::Message::Pong(b) => Self::Pong(b),
430+
ws::Message::Close(frame) => Self::Close(frame),
434431
}
435432
}
436433
}
437434

438-
fn datamsg_to_wsmsg(msg: DataMessage) -> WsMessage {
435+
fn datamsg_to_wsmsg(msg: DataMessage) -> ws::Message {
439436
match msg {
440-
DataMessage::Text(text) => WsMessage::Text(bytestring_to_utf8bytes(text)),
441-
DataMessage::Binary(bin) => WsMessage::Binary(bin),
437+
DataMessage::Text(text) => ws::Message::Text(bytestring_to_utf8bytes(text)),
438+
DataMessage::Binary(bin) => ws::Message::Binary(bin),
442439
}
443440
}
444441

445-
fn utf8bytes_to_bytestring(s: Utf8Bytes) -> ByteString {
442+
fn utf8bytes_to_bytestring(s: ws::Utf8Bytes) -> ByteString {
446443
// SAFETY: `Utf8Bytes` and `ByteString` have the same invariant of UTF-8 validity
447444
unsafe { ByteString::from_bytes_unchecked(Bytes::from(s)) }
448445
}
449-
fn bytestring_to_utf8bytes(s: ByteString) -> Utf8Bytes {
446+
fn bytestring_to_utf8bytes(s: ByteString) -> ws::Utf8Bytes {
450447
// SAFETY: `Utf8Bytes` and `ByteString` have the same invariant of UTF-8 validity
451-
unsafe { Utf8Bytes::from_bytes_unchecked(s.into_bytes()) }
448+
unsafe { ws::Utf8Bytes::try_from(s.into_bytes()).unwrap_unchecked() }
452449
}

0 commit comments

Comments
 (0)