diff --git a/async-nats/src/ext.rs b/async-nats/src/ext.rs new file mode 100644 index 000000000..94804bdc5 --- /dev/null +++ b/async-nats/src/ext.rs @@ -0,0 +1,169 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures::Stream; +use serde::de::DeserializeOwned; + +pub trait SubscribeExt: Stream +where + M: MessageTrait, +{ + fn for_type(self) -> TypedStream + where + Self: Sized, + T: DeserializeOwned, + { + TypedStream::new(self) + } +} + +impl SubscribeExt for S +where + S: Stream, + M: MessageTrait, +{ +} + +pin_project_lite::pin_project! { + pub struct TypedStream { + #[pin] + stream: S, + _phantom: std::marker::PhantomData, + } +} + +impl TypedStream { + fn new(stream: S) -> Self { + Self { + stream, + _phantom: std::marker::PhantomData, + } + } +} + +pub trait MessageTrait { + // fn payload(&self) -> Bytes; + // fn subject(&self) -> Subject; + // fn reply(&self) -> Option; + // fn headers(&self) -> Option; + // fn status(&self) -> Option; + // fn description(&self) -> Option; + // fn length(&self) -> usize; + fn payload(&self) -> &[u8]; +} + +impl MessageTrait for crate::Message { + fn payload(&self) -> &[u8] { + self.payload.as_ref() + } +} + +impl MessageTrait for crate::PublishMessage { + fn payload(&self) -> &[u8] { + self.payload.as_ref() + } +} + +impl MessageTrait for crate::jetstream::message::Message { + fn payload(&self) -> &[u8] { + self.payload.as_ref() + } +} + +impl Stream for TypedStream +where + S: Stream, + T: DeserializeOwned, + M: MessageTrait, +{ + type Item = serde_json::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match this.stream.poll_next(cx) { + Poll::Ready(message) => match message { + Some(message) => { + let message = message.payload(); + Poll::Ready(Some(serde_json::from_slice(&message))) + } + None => Poll::Ready(None), + }, + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod test { + use futures::StreamExt; + use futures::TryStreamExt; + use serde::Serialize; + + use super::SubscribeExt; + use crate::PublishMessage; + + #[tokio::test] + async fn for_type() { + use futures::stream; + use serde::Deserialize; + + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct Test { + a: i32, + b: String, + } + + struct OtherTest { + data: (i32, String), + } + + // Prepare some messages + let messages = vec![ + PublishMessage { + subject: "test".into(), + payload: serde_json::to_vec(&Test { + a: 1, + b: "a".to_string(), + }) + .unwrap() + .into(), + reply: None, + headers: Default::default(), + }, + PublishMessage { + subject: "test".into(), + payload: serde_json::to_vec(&Test { + a: 2, + b: "b".to_string(), + }) + .unwrap() + .into(), + reply: None, + headers: Default::default(), + }, + ]; + + // Simulate a stream of messages + let stream = stream::iter(messages); + + // first deserialize into a concrete type + let stream = stream + .for_type::() + // and then transform into another type + .and_then(|item| async move { + Ok(OtherTest { + data: (item.a, item.b), + }) + }); + + // Don't worry, that is just Rust bs about pinning data. + let mut stream = Box::pin(stream); + + // see that it works. + assert_eq!(stream.next().await.unwrap().unwrap().data.0, 1); + assert_eq!(stream.next().await.unwrap().unwrap().data.0, 2); + } +}