diff --git a/examples/crypto/main.rs b/examples/crypto/main.rs index 8dab74c..4b2b325 100644 --- a/examples/crypto/main.rs +++ b/examples/crypto/main.rs @@ -1,6 +1,7 @@ use std::fs; use tokio::fs::File; +use tokio::io::AsyncReadExt; use tokio::time::sleep; use dapr::client::ReaderStream; @@ -28,7 +29,7 @@ async fn main() -> Result<(), Box> { .await .unwrap(); - let decrypted = client + let mut decrypted = client .decrypt( encrypted, dapr::client::DecryptRequestOptions { @@ -39,7 +40,11 @@ async fn main() -> Result<(), Box> { .await .unwrap(); - assert_eq!(String::from_utf8(decrypted).unwrap().as_str(), "Test"); + let mut value = String::new(); + + decrypted.read_to_string(&mut value).await.unwrap(); + + assert_eq!(value.as_str(), "Test"); println!("Successfully Decrypted String"); @@ -60,7 +65,7 @@ async fn main() -> Result<(), Box> { .await .unwrap(); - let decrypted = client + let mut decrypted = client .decrypt( encrypted, dapr::client::DecryptRequestOptions { @@ -73,7 +78,11 @@ async fn main() -> Result<(), Box> { let image = fs::read("./image.png").unwrap(); - assert_eq!(decrypted, image); + let mut buf = bytes::BytesMut::with_capacity(image.len()); + + decrypted.read_buf(&mut buf).await.unwrap(); + + assert_eq!(buf.to_vec(), image); println!("Successfully Decrypted Image"); diff --git a/src/client.rs b/src/client.rs index 712280e..9e01a50 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; +use std::pin::Pin; +use std::task::{Context, Poll}; use async_trait::async_trait; use futures::StreamExt; use prost_types::Any; use serde::{Deserialize, Serialize}; -use tokio::io::AsyncRead; +use tokio::io::{AsyncRead, ReadBuf}; use tonic::codegen::tokio_stream; use tonic::{transport::Channel as TonicChannel, Request}; use tonic::{Status, Streaming}; @@ -394,7 +396,7 @@ impl Client { &mut self, payload: ReaderStream, request_options: EncryptRequestOptions, - ) -> Result, Status> + ) -> Result, Status> where R: AsyncRead + Send, { @@ -433,26 +435,27 @@ impl Client { /// * `options` - Decryption request options. pub async fn decrypt( &mut self, - encrypted: Vec, + mut encrypted_stream: ResponseStream, options: DecryptRequestOptions, - ) -> Result, Status> { - let requested_items: Vec = encrypted - .iter() - .enumerate() - .map(|(i, item)| { - if i == 0 { - DecryptRequest { - options: Some(options.clone()), - payload: Some(item.clone()), - } - } else { - DecryptRequest { - options: None, - payload: Some(item.clone()), + ) -> Result, Status> { + let mut requested_items = vec![]; + while let Some(resp_result) = encrypted_stream.stream.next().await { + if let Ok(resp) = resp_result { + if let Some(payload) = resp.payload { + if requested_items.len() == 0 { + requested_items.push(DecryptRequest { + options: Some(options.clone()), + payload: Some(payload), + }) + } else { + requested_items.push(DecryptRequest { + options: None, + payload: Some(payload), + }) } } - }) - .collect(); + } + } self.0.decrypt(requested_items).await } } @@ -497,10 +500,15 @@ pub trait DaprInterface: Sized { request: UnsubscribeConfigurationRequest, ) -> Result; - async fn encrypt(&mut self, payload: Vec) - -> Result, Status>; + async fn encrypt( + &mut self, + payload: Vec, + ) -> Result, Status>; - async fn decrypt(&mut self, payload: Vec) -> Result, Status>; + async fn decrypt( + &mut self, + payload: Vec, + ) -> Result, Status>; } #[async_trait] @@ -626,19 +634,10 @@ impl DaprInterface for dapr_v1::dapr_client::DaprClient { async fn encrypt( &mut self, request: Vec, - ) -> Result, Status> { + ) -> Result, Status> { let request = Request::new(tokio_stream::iter(request)); - let stream = self.encrypt_alpha1(request).await?; - let mut stream = stream.into_inner(); - let mut return_data = vec![]; - while let Some(resp) = stream.next().await { - if let Ok(resp) = resp { - if let Some(data) = resp.payload { - return_data.push(data) - } - } - } - Ok(return_data) + let stream = self.encrypt_alpha1(request).await?.into_inner(); + Ok(ResponseStream { stream }) } /// Decrypt binary data using Dapr. returns Vec. @@ -647,19 +646,13 @@ impl DaprInterface for dapr_v1::dapr_client::DaprClient { /// /// * `encrypted` - Encrypted data usually returned from encrypted, Vec /// * `options` - Decryption request options. - async fn decrypt(&mut self, request: Vec) -> Result, Status> { + async fn decrypt( + &mut self, + request: Vec, + ) -> Result, Status> { let request = Request::new(tokio_stream::iter(request)); - let stream = self.decrypt_alpha1(request).await?; - let mut stream = stream.into_inner(); - let mut data = vec![]; - while let Some(resp) = stream.next().await { - if let Ok(resp) = resp { - if let Some(mut payload) = resp.payload { - data.append(payload.data.as_mut()) - } - } - } - Ok(data) + let stream = self.decrypt_alpha1(request).await?.into_inner(); + Ok(ResponseStream { stream }) } } @@ -752,6 +745,10 @@ pub type EncryptRequestOptions = crate::dapr::dapr::proto::runtime::v1::EncryptR /// Decryption request options pub type DecryptRequestOptions = crate::dapr::dapr::proto::runtime::v1::DecryptRequestOptions; +pub type EncryptResponse = crate::dapr::dapr::proto::runtime::v1::EncryptResponse; + +pub type DecryptResponse = crate::dapr::dapr::proto::runtime::v1::DecryptResponse; + type StreamPayload = crate::dapr::dapr::proto::common::v1::StreamPayload; impl From<(K, Vec)> for common_v1::StateItem where @@ -773,3 +770,53 @@ impl ReaderStream { ReaderStream(tokio_util::io::ReaderStream::new(data)) } } + +pub struct ResponseStream { + stream: Streaming, +} + +impl AsyncRead for ResponseStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(resp))) => { + if let Some(payload) = resp.payload { + buf.put_slice(&payload.data); + } + Poll::Ready(Ok(())) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{:?}", e), + ))), + Poll::Ready(None) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } +} + +impl AsyncRead for ResponseStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(resp))) => { + if let Some(payload) = resp.payload { + buf.put_slice(&payload.data); + } + Poll::Ready(Ok(())) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{:?}", e), + ))), + Poll::Ready(None) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } +}