Skip to content

Commit

Permalink
Simplify the implementation a bit (#57)
Browse files Browse the repository at this point in the history
- Removed the `sender_loop` function, replaced it with `StreamExt::forward`
- Noticed the `GraphqlClient` generic parameter on `SubscriptionStream` wasn't doing anything so got rid of it.
- Switched the subscription IDs to be integers so I could remove the `uuid` dependency.  It didn't seem worth the extra compile time for that one feature.
  • Loading branch information
obmarg authored Jan 19, 2024
1 parent 2585839 commit 625054c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 78 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ all APIs might be changed.

## Unreleased - xxxx-xx-xx

### Breaking Changes

- Subscription IDs sent to the server are now just monotonic numbers rather
than uuids.
- `SubscriptionStream` no longer takes `GraphqlClient` as a generic parameter

## v0.7.0 - 2024-01-03

### Breaking Changes
Expand Down
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ members = ["examples", "examples-wasm"]
default = ["async-tungstenite"]
client-cynic = ["async-tungstenite", "cynic"]
client-graphql-client = ["async-tungstenite", "graphql_client"]
ws_stream_wasm = ["dep:ws_stream_wasm", "uuid/js", "no-logging", "pharos", "pin-project-lite"]
ws_stream_wasm = ["dep:ws_stream_wasm", "no-logging", "pharos", "pin-project-lite"]
no-logging = []

[dependencies]
Expand All @@ -31,7 +31,6 @@ pin-project = "1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
uuid = { version = "1.0", features = ["v4"] }

cynic = { version = "3", optional = true }
async-tungstenite = { version = "0.24", optional = true }
Expand Down
113 changes: 37 additions & 76 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
use std::{collections::HashMap, marker::PhantomData, pin::Pin, sync::Arc};
use std::{
collections::HashMap,
marker::PhantomData,
pin::Pin,
sync::{
atomic::{self, AtomicU64},
Arc,
},
};

use futures::{
channel::{mpsc, oneshot},
channel::mpsc,
future::RemoteHandle,
lock::Mutex,
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
task::{Context, Poll, SpawnExt},
};
use serde::Serialize;
use uuid::Uuid;

use super::{
graphql::{self, GraphqlOperation},
Expand All @@ -27,6 +34,7 @@ where
{
inner: Arc<ClientInner<GraphqlClient>>,
sender_sink: mpsc::Sender<WsMessage>,
next_id: AtomicU64,
phantom: PhantomData<GraphqlClient>,
}

Expand Down Expand Up @@ -133,15 +141,14 @@ where

let (mut sender_sink, sender_stream) = mpsc::channel(1);

let (shutdown_sender, shutdown_receiver) = oneshot::channel();

let sender_handle = runtime
.spawn_with_handle(sender_loop(
sender_stream,
websocket_sink,
Arc::clone(&operations),
shutdown_receiver,
))
.spawn_with_handle(async move {
sender_stream
.map(Ok)
.forward(websocket_sink)
.await
.map_err(|error| Error::Send(error.to_string()))
})
.map_err(|err| Error::SpawnHandle(err.to_string()))?;

// wait for ack before entering receiver loop:
Expand Down Expand Up @@ -185,7 +192,6 @@ where
websocket_stream,
sender_sink.clone(),
Arc::clone(&operations),
shutdown_sender,
))
.map_err(|err| Error::SpawnHandle(err.to_string()))?;

Expand All @@ -195,6 +201,7 @@ where
operations,
sender_handle,
}),
next_id: 0.into(),
sender_sink,
phantom: PhantomData,
})
Expand All @@ -218,12 +225,12 @@ where
pub async fn streaming_operation<'a, Operation>(
&mut self,
op: Operation,
) -> Result<SubscriptionStream<GraphqlClient, Operation>, Error>
) -> Result<SubscriptionStream<Operation>, Error>
where
Operation:
GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin + Send + 'static,
{
let id = Uuid::new_v4();
let id = self.next_id.fetch_add(1, atomic::Ordering::Relaxed);
let (sender, receiver) = mpsc::channel(SUBSCRIPTION_BUFFER_SIZE);

self.inner.operations.lock().await.insert(id, sender);
Expand All @@ -242,7 +249,7 @@ where
let mut sender_clone = self.sender_sink.clone();
let id_clone = id.to_string();

Ok(SubscriptionStream::<GraphqlClient, Operation> {
Ok(SubscriptionStream::<Operation> {
id: id.to_string(),
stream: Box::pin(receiver.map(move |response| {
op.decode(response)
Expand All @@ -260,7 +267,6 @@ where
Ok(())
})
}),
phantom: PhantomData,
})
}
}
Expand All @@ -269,32 +275,28 @@ where
///
/// Emits an item for each message received by the subscription.
#[pin_project::pin_project]
pub struct SubscriptionStream<GraphqlClient, Operation>
pub struct SubscriptionStream<Operation>
where
GraphqlClient: graphql::GraphqlClient,
Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response>,
Operation: GraphqlOperation,
{
id: String,
stream: Pin<Box<dyn Stream<Item = Result<Operation::Response, Error>> + Send>>,
cancel_func: Box<dyn FnOnce() -> futures::future::BoxFuture<'static, Result<(), Error>> + Send>,
phantom: PhantomData<GraphqlClient>,
}

impl<GraphqlClient, Operation> SubscriptionStream<GraphqlClient, Operation>
impl<Operation> SubscriptionStream<Operation>
where
GraphqlClient: graphql::GraphqlClient + Send,
Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Send,
Operation: GraphqlOperation + Send,
{
/// Stops the operation by sending a Complete message to the server.
pub async fn stop_operation(self) -> Result<(), Error> {
(self.cancel_func)().await
}
}

impl<GraphqlClient, Operation> Stream for SubscriptionStream<GraphqlClient, Operation>
impl<Operation> Stream for SubscriptionStream<Operation>
where
GraphqlClient: graphql::GraphqlClient,
Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin,
Operation: GraphqlOperation + Unpin,
{
type Item = Result<Operation::Response, Error>;

Expand All @@ -305,13 +307,12 @@ where

type OperationSender<GenericResponse> = mpsc::Sender<GenericResponse>;

type OperationMap<GenericResponse> = Arc<Mutex<HashMap<Uuid, OperationSender<GenericResponse>>>>;
type OperationMap<GenericResponse> = Arc<Mutex<HashMap<u64, OperationSender<GenericResponse>>>>;

async fn receiver_loop<S, WsMessage, GraphqlClient>(
mut receiver: S,
mut sender: mpsc::Sender<WsMessage>,
operations: OperationMap<GraphqlClient::Response>,
shutdown: oneshot::Sender<()>,
) -> Result<(), Error>
where
S: Stream<Item = Result<WsMessage, WsMessage::Error>> + Unpin,
Expand All @@ -330,9 +331,10 @@ where
}
}

shutdown
.send(())
.map_err(|_| Error::SenderShutdown("Couldn't shutdown sender".to_owned()))
// Clear out any operations
operations.lock().await.clear();

Ok(())
}

async fn handle_message<WsMessage, GraphqlClient>(
Expand All @@ -355,7 +357,10 @@ where
};

let id = match event.id() {
Some(id) => Some(Uuid::parse_str(id).map_err(|err| Error::Decode(err.to_string()))?),
Some(id) => Some(
id.parse::<u64>()
.map_err(|err| Error::Decode(err.to_string()))?,
),
None => None,
};

Expand Down Expand Up @@ -414,50 +419,6 @@ where
Ok(())
}

async fn sender_loop<M, S, E, GenericResponse>(
message_stream: mpsc::Receiver<M>,
mut ws_sender: S,
operations: OperationMap<GenericResponse>,
shutdown: oneshot::Receiver<()>,
) -> Result<(), Error>
where
M: WebsocketMessage,
S: Sink<M, Error = E> + Unpin,
E: std::error::Error,
{
use futures::{future::FutureExt, select};

let mut message_stream = message_stream.fuse();
let mut shutdown = shutdown.fuse();

loop {
select! {
msg = message_stream.next() => {
if let Some(msg) = msg {
trace!("Sending message: {:?}", msg);
ws_sender
.send(msg)
.await
.map_err(|err| Error::Send(err.to_string()))?;
} else {
return Ok(());
}
}
_ = shutdown => {
// Shutdown the incoming message stream
let mut message_stream = message_stream.into_inner();
message_stream.close();
while message_stream.next().await.is_some() {}

// Clear out any operations
operations.lock().await.clear();

return Ok(());
}
}
}
}

struct ClientInner<GraphqlClient>
where
GraphqlClient: crate::graphql::GraphqlClient,
Expand Down

0 comments on commit 625054c

Please sign in to comment.