Skip to content

Commit

Permalink
refactor to allow waiting for stream readiness
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Jul 1, 2024
1 parent e669a0f commit 6e1a961
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 44 deletions.
45 changes: 29 additions & 16 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::GenericError;
use std::collections::HashMap;
use std::convert::TryInto;
use std::sync::Arc;
use tokio::{task::{JoinHandle, AbortHandle}, sync::Mutex};
use tokio::{task::AbortHandle, sync::Mutex};
use xmtp_api_grpc::grpc_api_helper::Client as TonicApiClient;
use xmtp_id::{
associations::{
Expand Down Expand Up @@ -42,6 +42,7 @@ use xmtp_mls::{
EncryptedMessageStore, EncryptionKey, StorageOption,
},
client::ClientError,
subscriptions::StreamHandle,
};

pub type RustXmtpClient = MlsClient<TonicApiClient>;
Expand Down Expand Up @@ -80,8 +81,7 @@ pub async fn create_client(
legacy_signed_private_key_proto: Option<Vec<u8>>,
history_sync_url: Option<String>,
) -> Result<Arc<FfiXmtpClient>, GenericError> {
// TODO: revert
// init_logger(logger);
init_logger(logger);
log::info!(
"Creating API client for host: {}, isSecure: {}",
host,
Expand Down Expand Up @@ -1143,23 +1143,31 @@ impl From<StoredGroupMessage> for FfiMessage {

#[derive(uniffi::Object, Clone, Debug)]
pub struct FfiStreamCloser {
handle: Arc<Mutex<Option<JoinHandle<Result<(), ClientError>>>>>,
#[allow(clippy::type_complexity)]
stream_handle: Arc<Mutex<Option<StreamHandle<Result<(), ClientError>>>>>,
// for convenience, does not require locking mutex.
abort_handle: Arc<AbortHandle>,
}

impl FfiStreamCloser {
pub fn new(handle: JoinHandle<Result<(), ClientError>>) -> Self {
pub fn new(stream_handle: StreamHandle<Result<(), ClientError>>) -> Self {
Self {
abort_handle: Arc::new(handle.abort_handle()),
handle: Arc::new(Mutex::new(Some(handle))),
abort_handle: Arc::new(stream_handle.handle.abort_handle()),
stream_handle: Arc::new(Mutex::new(Some(stream_handle))),
}
}

#[cfg(test)]
pub async fn wait_for_ready(&self) {
let mut handle = self.stream_handle.lock().await;
if let Some(ref mut h) = &mut *handle {
h.wait_for_ready().await;
}
}
}

#[uniffi::export]
impl FfiStreamCloser {

/// Signal the stream to end
/// Does not wait for the stream to end.
pub fn end(&self) {
Expand All @@ -1168,15 +1176,16 @@ impl FfiStreamCloser {

/// End the stream and asyncronously wait for it to shutdown
pub async fn end_and_wait(&self) -> Result<(), GenericError> {

if self.abort_handle.is_finished() {
return Ok(());
}

let mut handle = self.handle.lock().await;
let handle = handle.take();
if let Some(h) = handle {
h.abort();
let join_result = h.await;
let mut stream_handle = self.stream_handle.lock().await;
let stream_handle = stream_handle.take();
if let Some(h) = stream_handle {
h.handle.abort();
let join_result = h.handle.await;
if matches!(join_result, Err(ref e) if !e.is_cancelled()) {
return Err(GenericError::Generic {
err: format!("subscription event loop join error {}", join_result.unwrap_err()),
Expand Down Expand Up @@ -1267,7 +1276,7 @@ mod tests {
get_inbox_id_for_address, inbox_owner::SigningError, logger::FfiLogger,
FfiConversationCallback, FfiCreateGroupOptions, FfiGroupPermissionsOptions, FfiInboxOwner,
FfiListConversationsOptions, FfiListMessagesOptions, FfiMetadataField, FfiPermissionPolicy,
FfiPermissionPolicySet, FfiPermissionUpdateType,
FfiPermissionPolicySet, FfiPermissionUpdateType, FfiGroup
};
use std::{
env,
Expand Down Expand Up @@ -1700,8 +1709,6 @@ mod tests {
// Looks like this test might be a separate issue
#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_can_stream_group_messages_for_updates() {
let _ = tracing_subscriber::fmt::try_init();

let alix = new_test_client().await;
let bo = new_test_client().await;

Expand All @@ -1710,6 +1717,7 @@ mod tests {
let stream_messages = bo
.conversations()
.stream_all_messages(Box::new(message_callbacks.clone()));
stream_messages.wait_for_ready().await;

// Create group and send first message
let alix_group = alix
Expand Down Expand Up @@ -1750,6 +1758,7 @@ mod tests {
assert_eq!(message_callbacks.message_count(), 3);

stream_messages.end_and_wait().await.unwrap();

assert!(stream_messages.is_closed());
}

Expand All @@ -1764,6 +1773,7 @@ mod tests {
let stream_messages = bo
.conversations()
.stream_all_messages(Box::new(message_callbacks.clone()));
stream_messages.wait_for_ready().await;

let first_msg_check = 2;
let second_msg_check = 5;
Expand Down Expand Up @@ -1887,6 +1897,7 @@ mod tests {
let stream = caro
.conversations()
.stream_all_messages(Box::new(stream_callback.clone()));
stream.wait_for_ready().await;

alix_group.send("first".as_bytes().to_vec()).await.unwrap();
stream_callback.wait_for_delivery().await;
Expand Down Expand Up @@ -1965,6 +1976,7 @@ mod tests {
let stream_closer = bola
.conversations()
.stream_all_messages(Box::new(stream_callback.clone()));
stream_closer.wait_for_ready().await;

amal_group.send(b"hello1".to_vec()).await.unwrap();
stream_callback.wait_for_delivery().await;
Expand Down Expand Up @@ -2066,6 +2078,7 @@ mod tests {
let stream_messages = bo
.conversations()
.stream_all_messages(Box::new(message_callback.clone()));
stream_messages.wait_for_ready().await;

// Create group and send first message
let alix_group = alix
Expand Down
19 changes: 10 additions & 9 deletions bindings_node/src/streams.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
use std::sync::Arc;
use tokio::{sync::Mutex, task::{JoinHandle, AbortHandle}};
use xmtp_mls::client::ClientError;
use tokio::{sync::Mutex, task::AbortHandle};
use xmtp_mls::{client::ClientError, subscriptions::StreamHandle};
use napi::bindgen_prelude::Error;

use napi_derive::napi;

#[napi]
pub struct NapiStreamCloser {
handle: Arc<Mutex<Option<JoinHandle<Result<(), ClientError>>>>>,
#[allow(clippy::type_complexity)]
handle: Arc<Mutex<Option<StreamHandle<Result<(), ClientError>>>>>,
// for convenience, does not require locking mutex.
abort_handle: Arc<AbortHandle>,
}

impl NapiStreamCloser {
pub fn new(handle: JoinHandle<Result<(), ClientError>>) -> Self {
pub fn new(handle: StreamHandle<Result<(), ClientError>>) -> Self {
Self {
abort_handle: Arc::new(handle.abort_handle()),
abort_handle: Arc::new(handle.handle.abort_handle()),
handle: Arc::new(Mutex::new(Some(handle))),
}
}
}

impl From<JoinHandle<Result<(), ClientError>>> for NapiStreamCloser {
fn from(handle: JoinHandle<Result<(), ClientError>>) -> Self {
impl From<StreamHandle<Result<(), ClientError>>> for NapiStreamCloser {
fn from(handle: StreamHandle<Result<(), ClientError>>) -> Self {
NapiStreamCloser::new(handle)
}
}
Expand All @@ -45,8 +46,8 @@ impl NapiStreamCloser {
let mut handle = self.handle.lock().await;
let handle = handle.take();
if let Some(h) = handle {
h.abort();
let join_result = h.await;
h.handle.abort();
let join_result = h.handle.await;
if matches!(join_result, Err(ref e) if !e.is_cancelled()) {
return Err(Error::from_reason(
format!("subscription event loop join error {}", join_result.unwrap_err())
Expand Down
5 changes: 2 additions & 3 deletions xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use std::pin::Pin;
use std::sync::Arc;

use futures::Stream;
use tokio::task::JoinHandle;

use super::{extract_message_v1, GroupError, MlsGroup};
use crate::storage::group_message::StoredGroupMessage;
use crate::subscriptions::MessagesStreamInfo;
use crate::subscriptions::{MessagesStreamInfo, StreamHandle};
use crate::XmtpApi;
use crate::{retry_async, retry::Retry, Client};
use prost::Message;
Expand Down Expand Up @@ -119,7 +118,7 @@ impl MlsGroup {
group_id: Vec<u8>,
created_at_ns: i64,
callback: impl FnMut(StoredGroupMessage) + Send + 'static,
) -> JoinHandle<Result<(), crate::groups::ClientError>>
) -> StreamHandle<Result<(), crate::groups::ClientError>>
where
ApiClient: crate::XmtpApi,
{
Expand Down
73 changes: 57 additions & 16 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use futures::{Stream, StreamExt};
use prost::Message;
use tokio::{
sync::mpsc::self,
sync::oneshot,
task::JoinHandle,
};
use tokio_stream::wrappers::UnboundedReceiverStream;
Expand All @@ -23,6 +24,29 @@ use crate::{
Client, XmtpApi,
};

#[derive(Debug)]
/// Wrapper around a [`tokio::task::JoinHandle`] but with a oneshot receiver
/// which allows waiting for a `with_callback` stream fn to be ready for stream items.
pub struct StreamHandle<T>{
pub handle: JoinHandle<T>,
start: Option<oneshot::Receiver<()>>
}

impl<T> StreamHandle<T> {
/// Waits for the stream to be fully spawned
pub async fn wait_for_ready(&mut self) {
if let Some(s) = self.start.take() {
let _ = s.await;
}
}
}

impl<T> From<StreamHandle<T>> for JoinHandle<T> {
fn from(stream: StreamHandle<T>) -> JoinHandle<T> {
stream.handle
}
}

#[derive(Clone, Debug)]
pub(crate) struct MessagesStreamInfo {
pub convo_created_at_ns: i64,
Expand Down Expand Up @@ -195,28 +219,44 @@ where
pub fn stream_conversations_with_callback(
client: Arc<Client<ApiClient>>,
mut convo_callback: impl FnMut(MlsGroup) + Send + 'static,
) -> JoinHandle<Result<(), ClientError>> {
tokio::spawn(async move {
) -> StreamHandle<Result<(), ClientError>> {
let (tx, rx) = oneshot::channel();

let handle = tokio::spawn(async move {
let mut stream = client.stream_conversations().await.unwrap();
let _ = tx.send(());
while let Some(convo) = stream.next().await {
convo_callback(convo)
}
Ok(())
})
});

StreamHandle {
start: Some(rx),
handle
}
}

pub(crate) fn stream_messages_with_callback(
client: Arc<Client<ApiClient>>,
group_id_to_info: HashMap<Vec<u8>, MessagesStreamInfo>,
mut callback: impl FnMut(StoredGroupMessage) + Send + 'static,
) -> JoinHandle<Result<(), ClientError>> {
tokio::spawn(async move {
) -> StreamHandle<Result<(), ClientError>> {
let (tx, rx) = oneshot::channel();

let handle = tokio::spawn(async move {
let mut stream = Self::stream_messages(client, group_id_to_info).await?;
let _ = tx.send(());
while let Some(message) = stream.next().await {
callback(message)
}
Ok(())
})
});

StreamHandle {
start: Some(rx),
handle
}
}

pub async fn stream_all_messages(
Expand Down Expand Up @@ -278,11 +318,9 @@ where
pub fn stream_all_messages_with_callback(
client: Arc<Client<ApiClient>>,
mut callback: impl FnMut(StoredGroupMessage) + Send + Sync + 'static,
) -> JoinHandle<Result<(), ClientError>> {
// make this call block until it is ready
// otherwise we miss messages
let (tx, rx) = tokio::sync::oneshot::channel();

) -> StreamHandle<Result<(), ClientError>> {
let (tx, rx) = oneshot::channel();

let handle = tokio::spawn(async move {
let mut stream = Self::stream_all_messages(client).await?;
let _ = tx.send(());
Expand All @@ -291,10 +329,11 @@ where
}
Ok(())
});

//TODO: dont need this?
let _ = tokio::task::block_in_place(|| rx.blocking_recv());
handle

StreamHandle {
start: Some(rx),
handle
}
}
}

Expand Down Expand Up @@ -356,10 +395,11 @@ mod tests {

let notify = Arc::new(tokio::sync::Notify::new());
let notify_pointer = notify.clone();
Client::<GrpcClient>::stream_all_messages_with_callback(Arc::new(caro), move |message| {
let handle = Client::<GrpcClient>::stream_all_messages_with_callback(Arc::new(caro), move |message| {
(*messages_clone.lock().unwrap()).push(message);
notify_pointer.notify_one();
});
handle.wait_for_ready().await;

alix_group
.send_message("first".as_bytes(), &alix)
Expand Down Expand Up @@ -412,6 +452,7 @@ mod tests {
notify_pointer.notify_one();
(*messages_clone.lock().unwrap()).push(message);
});
let handle = handle.wait_for_ready().await;

alix_group
.send_message("first".as_bytes(), &alix)
Expand Down

0 comments on commit 6e1a961

Please sign in to comment.