Skip to content

Commit dbe2c67

Browse files
nWackyLegNeato
andauthored
Minor improvements to subscriptions functionality (#591)
Co-authored-by: Christian Legnitto <[email protected]>
1 parent c91b989 commit dbe2c67

File tree

3 files changed

+133
-96
lines changed

3 files changed

+133
-96
lines changed

examples/warp_subscriptions/Cargo.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ serde_json = "1.0"
1414
tokio = { version = "0.2", features = ["rt-core", "macros"] }
1515
warp = "0.2.1"
1616

17-
# TODO#433: get crates from GitHub
18-
juniper = { path = "../../juniper" }
19-
juniper_subscriptions = { path = "../../juniper_subscriptions"}
20-
juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] }
17+
juniper = { git = "https://github.com/graphql-rust/juniper" }
18+
juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" }
19+
juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] }

examples/warp_subscriptions/src/main.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,20 @@ async fn main() {
165165
ctx: Context,
166166
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
167167
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
168-
graphql_subscriptions(websocket, coordinator, ctx).boxed()
168+
graphql_subscriptions(websocket, coordinator, ctx)
169+
.map(|r| {
170+
if let Err(e) = r {
171+
println!("Websocket error: {}", e);
172+
}
173+
})
174+
.boxed()
169175
})
170176
},
171177
))
172-
.map(|reply| {
173-
// TODO#584: remove this workaround
174-
warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")
175-
})
178+
.map(|reply| {
179+
// TODO#584: remove this workaround
180+
warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")
181+
})
176182
.or(warp::post()
177183
.and(warp::path("graphql"))
178184
.and(qm_graphql_filter))

juniper_warp/src/lib.rs

Lines changed: 119 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ fn playground_response(
442442
/// Cannot be merged to `juniper_warp` yet as GraphQL over WS[1]
443443
/// is not fully supported in current implementation.
444444
///
445+
/// *Note: this implementation is in an alpha state.*
446+
///
445447
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
446448
#[cfg(feature = "subscriptions")]
447449
pub mod subscriptions {
@@ -453,7 +455,7 @@ pub mod subscriptions {
453455
},
454456
};
455457

456-
use futures::{channel::mpsc, stream::StreamExt as _, Future};
458+
use futures::{channel::mpsc, Future, StreamExt as _};
457459
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
458460
use juniper_subscriptions::Coordinator;
459461
use serde::{Deserialize, Serialize};
@@ -467,7 +469,7 @@ pub mod subscriptions {
467469
websocket: warp::ws::WebSocket,
468470
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
469471
context: Context,
470-
) -> impl Future<Output = ()> + Send
472+
) -> impl Future<Output = Result<(), failure::Error>> + Send
471473
where
472474
S: ScalarValue + Send + Sync + 'static,
473475
Context: Clone + Send + Sync + 'static,
@@ -489,107 +491,137 @@ pub mod subscriptions {
489491
);
490492

491493
let context = Arc::new(context);
494+
let running = Arc::new(AtomicBool::new(false));
492495
let got_close_signal = Arc::new(AtomicBool::new(false));
493496

494-
sink_rx.for_each(move |msg| {
495-
let msg = msg.unwrap_or_else(|e| panic!("Websocket receive error: {}", e));
496-
497-
if msg.is_close() {
498-
return futures::future::ready(());
499-
}
500-
497+
sink_rx.fold(Ok(()), move |_, msg| {
501498
let coordinator = coordinator.clone();
502499
let context = context.clone();
500+
let running = running.clone();
503501
let got_close_signal = got_close_signal.clone();
502+
let ws_tx = ws_tx.clone();
503+
504+
async move {
505+
let msg = match msg {
506+
Ok(m) => m,
507+
Err(e) => {
508+
got_close_signal.store(true, Ordering::Relaxed);
509+
return Err(failure::format_err!("Websocket error: {}", e));
510+
}
511+
};
504512

505-
let msg = msg.to_str().expect("Non-text messages are not accepted");
506-
let request: WsPayload<S> = serde_json::from_str(msg).expect("Invalid WsPayload");
513+
if msg.is_close() {
514+
return Ok(());
515+
}
507516

508-
match request.type_name.as_str() {
509-
"connection_init" => {}
510-
"start" => {
511-
{
512-
let closed = got_close_signal.load(Ordering::Relaxed);
513-
if closed {
514-
return futures::future::ready(());
517+
let msg = msg
518+
.to_str()
519+
.map_err(|_| failure::format_err!("Non-text messages are not accepted"))?;
520+
let request: WsPayload<S> = serde_json::from_str(msg)
521+
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
522+
523+
match request.type_name.as_str() {
524+
"connection_init" => {}
525+
"start" => {
526+
{
527+
let closed = got_close_signal.load(Ordering::Relaxed);
528+
if closed {
529+
return Ok(());
530+
}
531+
532+
if running.load(Ordering::Relaxed) {
533+
return Ok(());
534+
}
535+
running.store(true, Ordering::Relaxed);
515536
}
516-
}
517537

518-
let ws_tx = ws_tx.clone();
538+
let ws_tx = ws_tx.clone();
519539

520-
tokio::task::spawn(async move {
521-
let payload = request.payload.expect("Could not deserialize payload");
522-
let request_id = request.id.unwrap_or("1".to_owned());
540+
if let Some(ref payload) = request.payload {
541+
if payload.query.is_none() {
542+
return Err(failure::format_err!("Query not found"));
543+
}
544+
} else {
545+
return Err(failure::format_err!("Payload not found"));
546+
}
523547

524-
let graphql_request = GraphQLRequest::<S>::new(
525-
payload.query.expect("Could not deserialize query"),
526-
None,
527-
payload.variables,
548+
tokio::task::spawn(async move {
549+
let payload = request.payload.unwrap();
550+
551+
let request_id = request.id.unwrap_or("1".to_owned());
552+
553+
let graphql_request = GraphQLRequest::<S>::new(
554+
payload.query.unwrap(),
555+
None,
556+
payload.variables,
557+
);
558+
559+
let values_stream =
560+
match coordinator.subscribe(&graphql_request, &context).await {
561+
Ok(s) => s,
562+
Err(err) => {
563+
let _ =
564+
ws_tx.unbounded_send(Some(Ok(Message::text(format!(
565+
r#"{{"type":"error","id":"{}","payload":{}}}"#,
566+
request_id,
567+
serde_json::ser::to_string(&err).unwrap_or(
568+
"Error deserializing GraphQLError".to_owned()
569+
)
570+
)))));
571+
572+
let close_message = format!(
573+
r#"{{"type":"complete","id":"{}","payload":null}}"#,
574+
request_id
575+
);
576+
let _ = ws_tx
577+
.unbounded_send(Some(Ok(Message::text(close_message))));
578+
// close channel
579+
let _ = ws_tx.unbounded_send(None);
580+
return;
581+
}
582+
};
583+
584+
values_stream
585+
.take_while(move |response| {
586+
let request_id = request_id.clone();
587+
let closed = got_close_signal.load(Ordering::Relaxed);
588+
if !closed {
589+
let mut response_text = serde_json::to_string(&response)
590+
.unwrap_or("Error deserializing response".to_owned());
591+
592+
response_text = format!(
593+
r#"{{"type":"data","id":"{}","payload":{} }}"#,
594+
request_id, response_text
595+
);
596+
597+
let _ = ws_tx
598+
.unbounded_send(Some(Ok(Message::text(response_text))));
599+
}
600+
601+
async move { !closed }
602+
})
603+
.for_each(|_| async {})
604+
.await;
605+
});
606+
}
607+
"stop" => {
608+
got_close_signal.store(true, Ordering::Relaxed);
609+
610+
let request_id = request.id.unwrap_or("1".to_owned());
611+
let close_message = format!(
612+
r#"{{"type":"complete","id":"{}","payload":null}}"#,
613+
request_id
528614
);
615+
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
529616

530-
let values_stream =
531-
match coordinator.subscribe(&graphql_request, &context).await {
532-
Ok(s) => s,
533-
Err(err) => {
534-
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(format!(
535-
r#"{{"type":"error","id":"{}","payload":{}}}"#,
536-
request_id,
537-
serde_json::ser::to_string(&err).unwrap_or(
538-
"Error deserializing GraphQLError".to_owned()
539-
)
540-
)))));
541-
542-
let close_message = format!(
543-
r#"{{"type":"complete","id":"{}","payload":null}}"#,
544-
request_id
545-
);
546-
let _ = ws_tx
547-
.unbounded_send(Some(Ok(Message::text(close_message))));
548-
// close channel
549-
let _ = ws_tx.unbounded_send(None);
550-
return;
551-
}
552-
};
553-
554-
values_stream
555-
.take_while(move |response| {
556-
let request_id = request_id.clone();
557-
let closed = got_close_signal.load(Ordering::Relaxed);
558-
if !closed {
559-
let mut response_text = serde_json::to_string(&response)
560-
.unwrap_or("Error deserializing respone".to_owned());
561-
562-
response_text = format!(
563-
r#"{{"type":"data","id":"{}","payload":{} }}"#,
564-
request_id, response_text
565-
);
566-
567-
let _ = ws_tx
568-
.unbounded_send(Some(Ok(Message::text(response_text))));
569-
}
570-
async move { !closed }
571-
})
572-
.for_each(|_| async {})
573-
.await;
574-
});
617+
// close channel
618+
let _ = ws_tx.unbounded_send(None);
619+
}
620+
_ => {}
575621
}
576-
"stop" => {
577-
got_close_signal.store(true, Ordering::Relaxed);
578-
579-
let request_id = request.id.unwrap_or("1".to_owned());
580-
let close_message = format!(
581-
r#"{{"type":"complete","id":"{}","payload":null}}"#,
582-
request_id
583-
);
584-
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
585622

586-
// close channel
587-
let _ = ws_tx.unbounded_send(None);
588-
}
589-
_ => {}
623+
Ok(())
590624
}
591-
592-
futures::future::ready(())
593625
})
594626
}
595627

0 commit comments

Comments
 (0)