diff --git a/src/next/actor.rs b/src/next/actor.rs index 8fc8807..b4f8166 100644 --- a/src/next/actor.rs +++ b/src/next/actor.rs @@ -25,7 +25,7 @@ use super::{ /// This type implements `IntoFuture` and should usually be spawned /// with an async runtime. pub struct ConnectionActor { - client: Option>, + client: async_channel::Receiver, connection: Box, operations: HashMap>, keep_alive: KeepAliveSettings, @@ -39,7 +39,7 @@ impl ConnectionActor { keep_alive: KeepAliveSettings, ) -> Self { ConnectionActor { - client: Some(client), + client, connection, operations: HashMap::new(), keep_alive_actor: Box::pin(keep_alive.run()), @@ -160,35 +160,24 @@ impl ConnectionActor { Message(Option), KeepAlive(Option), } - loop { - if let Some(client) = &mut self.client { - let command = async { Select::Command(client.recv().await.ok()) }; - let message = async { Select::Message(self.connection.receive().await) }; - let keep_alive = async { Select::KeepAlive(self.keep_alive_actor.next().await) }; - - match command.or(message).or(keep_alive).await { - Select::Command(Some(command)) | Select::KeepAlive(Some(command)) => { - return Some(Next::Command(command)); - } - Select::Command(None) => { - self.client.take(); - continue; - } - Select::Message(message) => { - self.keep_alive_actor = Box::pin(self.keep_alive.run()); - return Some(Next::Message(message?)); - } - Select::KeepAlive(None) => { - return Some(self.keep_alive.report_timeout()); - } - } - } - if self.operations.is_empty() { - // If client has disconnected and we have no running operations - // then we should shut down - return None; + let command = async { Select::Command(self.client.recv().await.ok()) }; + let message = async { Select::Message(self.connection.receive().await) }; + let keep_alive = async { Select::KeepAlive(self.keep_alive_actor.next().await) }; + + match command.or(message).or(keep_alive).await { + Select::Command(Some(command)) | Select::KeepAlive(Some(command)) => { + Some(Next::Command(command)) + } + Select::Command(None) => { + // All clients have disconnected + None + } + Select::Message(message) => { + self.keep_alive_actor = Box::pin(self.keep_alive.run()); + Some(Next::Message(message?)) } + Select::KeepAlive(None) => Some(self.keep_alive.report_timeout()), } } } diff --git a/src/next/connection.rs b/src/next/connection.rs index d46891d..6044710 100644 --- a/src/next/connection.rs +++ b/src/next/connection.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use crate::Error; -/// Abstrction around a websocket connection. +/// Abstraction around a websocket connection. /// /// Built in implementations are provided for `ws_stream_wasm` & `async_tungstenite`. /// diff --git a/src/protocol.rs b/src/protocol.rs index 4ae5865..73e62cf 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -45,6 +45,7 @@ pub enum Message<'a, Operation> { Pong, } +#[allow(dead_code)] #[derive(serde::Deserialize, Debug)] #[serde(tag = "type")] pub enum Event { diff --git a/tests/graphql-client-tests.rs b/tests/graphql-client-tests.rs index 8cd6d2f..dab961c 100644 --- a/tests/graphql-client-tests.rs +++ b/tests/graphql-client-tests.rs @@ -1,6 +1,7 @@ use std::{future::IntoFuture, time::Duration}; use assert_matches::assert_matches; +use async_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue}; use futures_lite::{future, StreamExt}; use graphql_client::GraphQLQuery; use graphql_ws_client::graphql::StreamingOperation; @@ -19,8 +20,6 @@ struct BooksChanged; #[tokio::test] async fn main_test() { - use async_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue}; - let server = SubscriptionServer::start().await; sleep(Duration::from_millis(20)).await; @@ -82,8 +81,6 @@ async fn main_test() { #[tokio::test] async fn oneshot_operation_test() { - use async_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue}; - let server = SubscriptionServer::start().await; sleep(Duration::from_millis(20)).await; @@ -141,6 +138,67 @@ async fn oneshot_operation_test() { .await; } +#[tokio::test] +async fn multiple_clients_test() { + async fn inner(server: &SubscriptionServer) { + // Open connection + let mut request = server.websocket_url().into_client_request().unwrap(); + request.headers_mut().insert( + "Sec-WebSocket-Protocol", + HeaderValue::from_str("graphql-transport-ws").unwrap(), + ); + let (connection, _) = async_tungstenite::tokio::connect_async(request) + .await + .unwrap(); + + // Connect / Subscribe + let (client, actor) = graphql_ws_client::Client::build(connection).await.unwrap(); + tokio::spawn(actor.into_future()); + let mut stream = client.subscribe(build_query()).await.unwrap(); + + sleep(Duration::from_millis(20)).await; + + // Send / Receive + server + .send(subscription_server::BookChanged { + id: "123".into(), + book: None, + }) + .unwrap(); + let update = stream.next().await.unwrap().unwrap(); + assert_eq!(update.data.unwrap().books.id, "123"); + } + + // Start server + let server = SubscriptionServer::start().await; + sleep(Duration::from_millis(20)).await; + + // Open connection + let mut request = server.websocket_url().into_client_request().unwrap(); + request.headers_mut().insert( + "Sec-WebSocket-Protocol", + HeaderValue::from_str("graphql-transport-ws").unwrap(), + ); + let (connection, _) = async_tungstenite::tokio::connect_async(request) + .await + .unwrap(); + + // Connect / Subscribe + let (client, actor) = graphql_ws_client::Client::build(connection).await.unwrap(); + tokio::spawn(actor.into_future()); + let mut stream = client.subscribe(build_query()).await.unwrap(); + + // Spawn another client + inner(&server).await; + + // Receive + let update = stream.next().await.unwrap().unwrap(); + assert_eq!(update.data.unwrap().books.id, "123"); + + let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await; + assert!(res.is_err()) +} + fn build_query() -> graphql_ws_client::graphql::StreamingOperation { StreamingOperation::new(books_changed::Variables) }