@@ -3,6 +3,7 @@ use std::mem;
3
3
use std:: pin:: { pin, Pin } ;
4
4
use std:: time:: Duration ;
5
5
6
+ use axum:: extract:: ws;
6
7
use axum:: extract:: { Path , Query , State } ;
7
8
use axum:: response:: IntoResponse ;
8
9
use axum:: Extension ;
@@ -24,12 +25,8 @@ use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
24
25
use spacetimedb_lib:: connection_id:: { ConnectionId , ConnectionIdForUrl } ;
25
26
use std:: time:: Instant ;
26
27
use tokio:: sync:: mpsc;
27
- use tokio_tungstenite:: tungstenite:: Utf8Bytes ;
28
28
29
29
use crate :: auth:: SpacetimeAuth ;
30
- use crate :: util:: websocket:: {
31
- CloseCode , CloseFrame , Message as WsMessage , WebSocketConfig , WebSocketStream , WebSocketUpgrade ,
32
- } ;
33
30
use crate :: util:: { NameOrIdentity , XForwardedFor } ;
34
31
use crate :: { log_and_500, ControlStateDelegate , NodeDelegate } ;
35
32
@@ -68,7 +65,7 @@ pub async fn handle_websocket<S>(
68
65
} ) : Query < SubscribeQueryParams > ,
69
66
forwarded_for : Option < TypedHeader < XForwardedFor > > ,
70
67
Extension ( auth) : Extension < SpacetimeAuth > ,
71
- ws : WebSocketUpgrade ,
68
+ ws : ws :: WebSocketUpgrade ,
72
69
) -> axum:: response:: Result < impl IntoResponse >
73
70
where
74
71
S : NodeDelegate + ControlStateDelegate ,
91
88
92
89
let db_identity = name_or_identity. resolve ( & ctx) . await ?;
93
90
94
- let ( res, ws_upgrade, protocol) =
95
- ws. select_protocol ( [ ( BIN_PROTOCOL , Protocol :: Binary ) , ( TEXT_PROTOCOL , Protocol :: Text ) ] ) ;
91
+ let ws = ws. protocols ( [ ws_api:: BIN_PROTOCOL , ws_api:: TEXT_PROTOCOL ] ) ;
92
+
93
+ let protocol = ws. selected_protocol ( ) . and_then ( |proto| {
94
+ if proto == BIN_PROTOCOL {
95
+ Some ( Protocol :: Binary )
96
+ } else if proto == TEXT_PROTOCOL {
97
+ Some ( Protocol :: Text )
98
+ } else {
99
+ None
100
+ }
101
+ } ) ;
96
102
97
103
let protocol = protocol. ok_or ( ( StatusCode :: BAD_REQUEST , "no valid protocol selected" ) ) ?;
98
104
let client_config = ClientConfig {
@@ -125,20 +131,13 @@ where
125
131
name : ctx. client_actor_index ( ) . next_client_name ( ) ,
126
132
} ;
127
133
128
- let ws_config = WebSocketConfig :: default ( )
129
- . max_message_size ( Some ( 0x2000000 ) )
130
- . max_frame_size ( None )
131
- . accept_unmasked_frames ( false ) ;
132
-
133
- tokio:: spawn ( async move {
134
- let ws = match ws_upgrade. upgrade ( ws_config) . await {
135
- Ok ( ws) => ws,
136
- Err ( err) => {
137
- log:: error!( "WebSocket init error: {}" , err) ;
138
- return ;
139
- }
140
- } ;
134
+ let ws = ws
135
+ . max_message_size ( 0x2000000 )
136
+ . max_frame_size ( usize:: MAX )
137
+ . accept_unmasked_frames ( false )
138
+ . on_failed_upgrade ( |err| log:: error!( "WebSocket init error: {}" , err) ) ;
141
139
140
+ let res = ws. on_upgrade ( move |ws| async move {
142
141
match forwarded_for {
143
142
Some ( TypedHeader ( XForwardedFor ( ip) ) ) => {
144
143
log:: debug!( "New client connected from ip {}" , ip)
@@ -180,7 +179,7 @@ where
180
179
181
180
const LIVELINESS_TIMEOUT : Duration = Duration :: from_secs ( 60 ) ;
182
181
183
- async fn ws_client_actor ( client : ClientConnection , ws : WebSocketStream , sendrx : mpsc:: Receiver < SerializableMessage > ) {
182
+ async fn ws_client_actor ( client : ClientConnection , ws : ws :: WebSocket , sendrx : mpsc:: Receiver < SerializableMessage > ) {
184
183
// ensure that even if this task gets cancelled, we always cleanup the connection
185
184
let mut client = scopeguard:: guard ( client, |client| {
186
185
tokio:: spawn ( client. disconnect ( ) ) ;
@@ -201,7 +200,7 @@ async fn make_progress<Fut: Future>(fut: &mut Pin<&mut MaybeDone<Fut>>) {
201
200
202
201
async fn ws_client_actor_inner (
203
202
client : & mut ClientConnection ,
204
- mut ws : WebSocketStream ,
203
+ mut ws : ws :: WebSocket ,
205
204
mut sendrx : mpsc:: Receiver < SerializableMessage > ,
206
205
) {
207
206
let mut liveness_check_interval = tokio:: time:: interval ( LIVELINESS_TIMEOUT ) ;
@@ -280,7 +279,7 @@ async fn ws_client_actor_inner(
280
279
let workload = msg. workload( ) ;
281
280
let num_rows = msg. num_rows( ) ;
282
281
283
- let msg = datamsg_to_wsmsg ( serialize( msg, client. config) ) ;
282
+ let msg = serialize( msg, client. config) ;
284
283
285
284
// These metrics should be updated together,
286
285
// or not at all.
@@ -295,7 +294,7 @@ async fn ws_client_actor_inner(
295
294
. observe( msg. len( ) as f64 ) ;
296
295
}
297
296
// feed() buffers the message, but does not necessarily send it
298
- ws. feed( msg) . await ?;
297
+ ws. feed( datamsg_to_wsmsg ( msg) ) . await ?;
299
298
}
300
299
// now we flush all the messages to the socket
301
300
ws. flush( ) . await
@@ -323,7 +322,7 @@ async fn ws_client_actor_inner(
323
322
// Send a close frame while continuing to poll the `handle_queue`,
324
323
// to avoid deadlocks or delays due to enqueued futures holding resources.
325
324
let close = also_poll(
326
- ws. close ( Some ( CloseFrame { code: CloseCode :: Away , reason: "module exited" . into( ) } ) ) ,
325
+ ws. send ( ws :: Message :: Close ( Some ( ws :: CloseFrame { code: ws :: close_code :: AWAY , reason: "module exited" . into( ) } ) ) ) ,
327
326
make_progress( & mut current_message) ,
328
327
) ;
329
328
if let Err ( e) = close. await {
@@ -341,7 +340,7 @@ async fn ws_client_actor_inner(
341
340
if mem:: take( & mut got_pong) {
342
341
// Send a ping message while continuing to poll the `handle_queue`,
343
342
// to avoid deadlocks or delays due to enqueued futures holding resources.
344
- if let Err ( e) = also_poll( ws. send( WsMessage :: Ping ( Bytes :: new( ) ) ) , make_progress( & mut current_message) ) . await {
343
+ if let Err ( e) = also_poll( ws. send( ws :: Message :: Ping ( Bytes :: new( ) ) ) , make_progress( & mut current_message) ) . await {
345
344
log:: warn!( "error sending ping: {e:#}" ) ;
346
345
}
347
346
continue ;
@@ -376,10 +375,10 @@ async fn ws_client_actor_inner(
376
375
}
377
376
log:: debug!( "Client caused error on text message: {}" , e) ;
378
377
if let Err ( e) = ws
379
- . close ( Some ( CloseFrame {
380
- code : CloseCode :: Error ,
378
+ . send ( ws :: Message :: Close ( Some ( ws :: CloseFrame {
379
+ code : ws :: close_code :: ERROR ,
381
380
reason : format ! ( "{e:#}" ) . into ( ) ,
382
- } ) )
381
+ } ) ) )
383
382
. await
384
383
{
385
384
log:: warn!( "error closing websocket: {e:#}" )
@@ -419,34 +418,32 @@ enum ClientMessage {
419
418
Message ( DataMessage ) ,
420
419
Ping ( Bytes ) ,
421
420
Pong ( Bytes ) ,
422
- Close ( Option < CloseFrame > ) ,
421
+ Close ( Option < ws :: CloseFrame > ) ,
423
422
}
424
423
impl ClientMessage {
425
- fn from_message ( msg : WsMessage ) -> Self {
424
+ fn from_message ( msg : ws :: Message ) -> Self {
426
425
match msg {
427
- WsMessage :: Text ( s) => Self :: Message ( DataMessage :: Text ( utf8bytes_to_bytestring ( s) ) ) ,
428
- WsMessage :: Binary ( b) => Self :: Message ( DataMessage :: Binary ( b) ) ,
429
- WsMessage :: Ping ( b) => Self :: Ping ( b) ,
430
- WsMessage :: Pong ( b) => Self :: Pong ( b) ,
431
- WsMessage :: Close ( frame) => Self :: Close ( frame) ,
432
- // WebSocket::read_message() never returns a raw Message::Frame
433
- WsMessage :: Frame ( _) => unreachable ! ( ) ,
426
+ ws:: Message :: Text ( s) => Self :: Message ( DataMessage :: Text ( utf8bytes_to_bytestring ( s) ) ) ,
427
+ ws:: Message :: Binary ( b) => Self :: Message ( DataMessage :: Binary ( b) ) ,
428
+ ws:: Message :: Ping ( b) => Self :: Ping ( b) ,
429
+ ws:: Message :: Pong ( b) => Self :: Pong ( b) ,
430
+ ws:: Message :: Close ( frame) => Self :: Close ( frame) ,
434
431
}
435
432
}
436
433
}
437
434
438
- fn datamsg_to_wsmsg ( msg : DataMessage ) -> WsMessage {
435
+ fn datamsg_to_wsmsg ( msg : DataMessage ) -> ws :: Message {
439
436
match msg {
440
- DataMessage :: Text ( text) => WsMessage :: Text ( bytestring_to_utf8bytes ( text) ) ,
441
- DataMessage :: Binary ( bin) => WsMessage :: Binary ( bin) ,
437
+ DataMessage :: Text ( text) => ws :: Message :: Text ( bytestring_to_utf8bytes ( text) ) ,
438
+ DataMessage :: Binary ( bin) => ws :: Message :: Binary ( bin) ,
442
439
}
443
440
}
444
441
445
- fn utf8bytes_to_bytestring ( s : Utf8Bytes ) -> ByteString {
442
+ fn utf8bytes_to_bytestring ( s : ws :: Utf8Bytes ) -> ByteString {
446
443
// SAFETY: `Utf8Bytes` and `ByteString` have the same invariant of UTF-8 validity
447
444
unsafe { ByteString :: from_bytes_unchecked ( Bytes :: from ( s) ) }
448
445
}
449
- fn bytestring_to_utf8bytes ( s : ByteString ) -> Utf8Bytes {
446
+ fn bytestring_to_utf8bytes ( s : ByteString ) -> ws :: Utf8Bytes {
450
447
// SAFETY: `Utf8Bytes` and `ByteString` have the same invariant of UTF-8 validity
451
- unsafe { Utf8Bytes :: from_bytes_unchecked ( s. into_bytes ( ) ) }
448
+ unsafe { ws :: Utf8Bytes :: try_from ( s. into_bytes ( ) ) . unwrap_unchecked ( ) }
452
449
}
0 commit comments