From 6e1a9617f79a9bafa940808541548ee09af096ac Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Mon, 1 Jul 2024 16:52:57 -0400 Subject: [PATCH] refactor to allow waiting for stream readiness --- bindings_ffi/src/mls.rs | 45 +++++++++++------ bindings_node/src/streams.rs | 19 ++++---- xmtp_mls/src/groups/subscriptions.rs | 5 +- xmtp_mls/src/subscriptions.rs | 73 ++++++++++++++++++++++------ 4 files changed, 98 insertions(+), 44 deletions(-) diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 52b7b4db8..3b0fd60dd 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -5,7 +5,7 @@ use crate::GenericError; use std::collections::HashMap; use std::convert::TryInto; use std::sync::Arc; -use tokio::{task::{JoinHandle, AbortHandle}, sync::Mutex}; +use tokio::{task::AbortHandle, sync::Mutex}; use xmtp_api_grpc::grpc_api_helper::Client as TonicApiClient; use xmtp_id::{ associations::{ @@ -42,6 +42,7 @@ use xmtp_mls::{ EncryptedMessageStore, EncryptionKey, StorageOption, }, client::ClientError, + subscriptions::StreamHandle, }; pub type RustXmtpClient = MlsClient; @@ -80,8 +81,7 @@ pub async fn create_client( legacy_signed_private_key_proto: Option>, history_sync_url: Option, ) -> Result, GenericError> { - // TODO: revert - // init_logger(logger); + init_logger(logger); log::info!( "Creating API client for host: {}, isSecure: {}", host, @@ -1143,23 +1143,31 @@ impl From for FfiMessage { #[derive(uniffi::Object, Clone, Debug)] pub struct FfiStreamCloser { - handle: Arc>>>>, + #[allow(clippy::type_complexity)] + stream_handle: Arc>>>>, // for convenience, does not require locking mutex. abort_handle: Arc, } impl FfiStreamCloser { - pub fn new(handle: JoinHandle>) -> Self { + pub fn new(stream_handle: StreamHandle>) -> Self { Self { - abort_handle: Arc::new(handle.abort_handle()), - handle: Arc::new(Mutex::new(Some(handle))), + abort_handle: Arc::new(stream_handle.handle.abort_handle()), + stream_handle: Arc::new(Mutex::new(Some(stream_handle))), + } + } + + #[cfg(test)] + pub async fn wait_for_ready(&self) { + let mut handle = self.stream_handle.lock().await; + if let Some(ref mut h) = &mut *handle { + h.wait_for_ready().await; } } } #[uniffi::export] impl FfiStreamCloser { - /// Signal the stream to end /// Does not wait for the stream to end. pub fn end(&self) { @@ -1168,15 +1176,16 @@ impl FfiStreamCloser { /// End the stream and asyncronously wait for it to shutdown pub async fn end_and_wait(&self) -> Result<(), GenericError> { + if self.abort_handle.is_finished() { return Ok(()); } - let mut handle = self.handle.lock().await; - let handle = handle.take(); - if let Some(h) = handle { - h.abort(); - let join_result = h.await; + let mut stream_handle = self.stream_handle.lock().await; + let stream_handle = stream_handle.take(); + if let Some(h) = stream_handle { + h.handle.abort(); + let join_result = h.handle.await; if matches!(join_result, Err(ref e) if !e.is_cancelled()) { return Err(GenericError::Generic { err: format!("subscription event loop join error {}", join_result.unwrap_err()), @@ -1267,7 +1276,7 @@ mod tests { get_inbox_id_for_address, inbox_owner::SigningError, logger::FfiLogger, FfiConversationCallback, FfiCreateGroupOptions, FfiGroupPermissionsOptions, FfiInboxOwner, FfiListConversationsOptions, FfiListMessagesOptions, FfiMetadataField, FfiPermissionPolicy, - FfiPermissionPolicySet, FfiPermissionUpdateType, + FfiPermissionPolicySet, FfiPermissionUpdateType, FfiGroup }; use std::{ env, @@ -1700,8 +1709,6 @@ mod tests { // Looks like this test might be a separate issue #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn test_can_stream_group_messages_for_updates() { - let _ = tracing_subscriber::fmt::try_init(); - let alix = new_test_client().await; let bo = new_test_client().await; @@ -1710,6 +1717,7 @@ mod tests { let stream_messages = bo .conversations() .stream_all_messages(Box::new(message_callbacks.clone())); + stream_messages.wait_for_ready().await; // Create group and send first message let alix_group = alix @@ -1750,6 +1758,7 @@ mod tests { assert_eq!(message_callbacks.message_count(), 3); stream_messages.end_and_wait().await.unwrap(); + assert!(stream_messages.is_closed()); } @@ -1764,6 +1773,7 @@ mod tests { let stream_messages = bo .conversations() .stream_all_messages(Box::new(message_callbacks.clone())); + stream_messages.wait_for_ready().await; let first_msg_check = 2; let second_msg_check = 5; @@ -1887,6 +1897,7 @@ mod tests { let stream = caro .conversations() .stream_all_messages(Box::new(stream_callback.clone())); + stream.wait_for_ready().await; alix_group.send("first".as_bytes().to_vec()).await.unwrap(); stream_callback.wait_for_delivery().await; @@ -1965,6 +1976,7 @@ mod tests { let stream_closer = bola .conversations() .stream_all_messages(Box::new(stream_callback.clone())); + stream_closer.wait_for_ready().await; amal_group.send(b"hello1".to_vec()).await.unwrap(); stream_callback.wait_for_delivery().await; @@ -2066,6 +2078,7 @@ mod tests { let stream_messages = bo .conversations() .stream_all_messages(Box::new(message_callback.clone())); + stream_messages.wait_for_ready().await; // Create group and send first message let alix_group = alix diff --git a/bindings_node/src/streams.rs b/bindings_node/src/streams.rs index 58568bf1b..cbafa8ee0 100644 --- a/bindings_node/src/streams.rs +++ b/bindings_node/src/streams.rs @@ -1,28 +1,29 @@ use std::sync::Arc; -use tokio::{sync::Mutex, task::{JoinHandle, AbortHandle}}; -use xmtp_mls::client::ClientError; +use tokio::{sync::Mutex, task::AbortHandle}; +use xmtp_mls::{client::ClientError, subscriptions::StreamHandle}; use napi::bindgen_prelude::Error; use napi_derive::napi; #[napi] pub struct NapiStreamCloser { - handle: Arc>>>>, + #[allow(clippy::type_complexity)] + handle: Arc>>>>, // for convenience, does not require locking mutex. abort_handle: Arc, } impl NapiStreamCloser { - pub fn new(handle: JoinHandle>) -> Self { + pub fn new(handle: StreamHandle>) -> Self { Self { - abort_handle: Arc::new(handle.abort_handle()), + abort_handle: Arc::new(handle.handle.abort_handle()), handle: Arc::new(Mutex::new(Some(handle))), } } } -impl From>> for NapiStreamCloser { - fn from(handle: JoinHandle>) -> Self { +impl From>> for NapiStreamCloser { + fn from(handle: StreamHandle>) -> Self { NapiStreamCloser::new(handle) } } @@ -45,8 +46,8 @@ impl NapiStreamCloser { let mut handle = self.handle.lock().await; let handle = handle.take(); if let Some(h) = handle { - h.abort(); - let join_result = h.await; + h.handle.abort(); + let join_result = h.handle.await; if matches!(join_result, Err(ref e) if !e.is_cancelled()) { return Err(Error::from_reason( format!("subscription event loop join error {}", join_result.unwrap_err()) diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index e4888a44e..720fe37cb 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -3,11 +3,10 @@ use std::pin::Pin; use std::sync::Arc; use futures::Stream; -use tokio::task::JoinHandle; use super::{extract_message_v1, GroupError, MlsGroup}; use crate::storage::group_message::StoredGroupMessage; -use crate::subscriptions::MessagesStreamInfo; +use crate::subscriptions::{MessagesStreamInfo, StreamHandle}; use crate::XmtpApi; use crate::{retry_async, retry::Retry, Client}; use prost::Message; @@ -119,7 +118,7 @@ impl MlsGroup { group_id: Vec, created_at_ns: i64, callback: impl FnMut(StoredGroupMessage) + Send + 'static, - ) -> JoinHandle> + ) -> StreamHandle> where ApiClient: crate::XmtpApi, { diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index c6261119a..13b49e3bb 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -8,6 +8,7 @@ use futures::{Stream, StreamExt}; use prost::Message; use tokio::{ sync::mpsc::self, + sync::oneshot, task::JoinHandle, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -23,6 +24,29 @@ use crate::{ Client, XmtpApi, }; +#[derive(Debug)] +/// Wrapper around a [`tokio::task::JoinHandle`] but with a oneshot receiver +/// which allows waiting for a `with_callback` stream fn to be ready for stream items. +pub struct StreamHandle{ + pub handle: JoinHandle, + start: Option> +} + +impl StreamHandle { + /// Waits for the stream to be fully spawned + pub async fn wait_for_ready(&mut self) { + if let Some(s) = self.start.take() { + let _ = s.await; + } + } +} + +impl From> for JoinHandle { + fn from(stream: StreamHandle) -> JoinHandle { + stream.handle + } +} + #[derive(Clone, Debug)] pub(crate) struct MessagesStreamInfo { pub convo_created_at_ns: i64, @@ -195,28 +219,44 @@ where pub fn stream_conversations_with_callback( client: Arc>, mut convo_callback: impl FnMut(MlsGroup) + Send + 'static, - ) -> JoinHandle> { - tokio::spawn(async move { + ) -> StreamHandle> { + let (tx, rx) = oneshot::channel(); + + let handle = tokio::spawn(async move { let mut stream = client.stream_conversations().await.unwrap(); + let _ = tx.send(()); while let Some(convo) = stream.next().await { convo_callback(convo) } Ok(()) - }) + }); + + StreamHandle { + start: Some(rx), + handle + } } pub(crate) fn stream_messages_with_callback( client: Arc>, group_id_to_info: HashMap, MessagesStreamInfo>, mut callback: impl FnMut(StoredGroupMessage) + Send + 'static, - ) -> JoinHandle> { - tokio::spawn(async move { + ) -> StreamHandle> { + let (tx, rx) = oneshot::channel(); + + let handle = tokio::spawn(async move { let mut stream = Self::stream_messages(client, group_id_to_info).await?; + let _ = tx.send(()); while let Some(message) = stream.next().await { callback(message) } Ok(()) - }) + }); + + StreamHandle { + start: Some(rx), + handle + } } pub async fn stream_all_messages( @@ -278,11 +318,9 @@ where pub fn stream_all_messages_with_callback( client: Arc>, mut callback: impl FnMut(StoredGroupMessage) + Send + Sync + 'static, - ) -> JoinHandle> { - // make this call block until it is ready - // otherwise we miss messages - let (tx, rx) = tokio::sync::oneshot::channel(); - + ) -> StreamHandle> { + let (tx, rx) = oneshot::channel(); + let handle = tokio::spawn(async move { let mut stream = Self::stream_all_messages(client).await?; let _ = tx.send(()); @@ -291,10 +329,11 @@ where } Ok(()) }); - - //TODO: dont need this? - let _ = tokio::task::block_in_place(|| rx.blocking_recv()); - handle + + StreamHandle { + start: Some(rx), + handle + } } } @@ -356,10 +395,11 @@ mod tests { let notify = Arc::new(tokio::sync::Notify::new()); let notify_pointer = notify.clone(); - Client::::stream_all_messages_with_callback(Arc::new(caro), move |message| { + let handle = Client::::stream_all_messages_with_callback(Arc::new(caro), move |message| { (*messages_clone.lock().unwrap()).push(message); notify_pointer.notify_one(); }); + handle.wait_for_ready().await; alix_group .send_message("first".as_bytes(), &alix) @@ -412,6 +452,7 @@ mod tests { notify_pointer.notify_one(); (*messages_clone.lock().unwrap()).push(message); }); + let handle = handle.wait_for_ready().await; alix_group .send_message("first".as_bytes(), &alix)